1 //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements an ODS (and C++) generator from a YAML form
10 // derived from the mathematical expression of linalg named ops. Typically a
11 // math oriented DSL will be used to export the essential representation to
12 // this form, and maintaining the SOT at the math level (versus recreating it
13 // in MLIR) is deemed to have systemic value.
15 //===----------------------------------------------------------------------===//
17 #include "mlir/AsmParser/AsmParser.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/ToolOutputFile.h"
28 #include "llvm/Support/YAMLTraits.h"
33 using llvm::yaml::Input
;
34 using llvm::yaml::MappingTraits
;
35 using llvm::yaml::ScalarEnumerationTraits
;
36 using llvm::yaml::ScalarTraits
;
38 #define DEBUG_TYPE "linalg-ods-gen"
40 //===----------------------------------------------------------------------===//
41 // Mapping structs (correspond to data types in the YAML description).
42 // TODO: Since this is a schema/part of the contract, it should be moved to
44 //===----------------------------------------------------------------------===//
48 struct LinalgYAMLContext
{
49 MLIRContext
*mlirContext
;
52 struct LinalgOpMetadata
{
54 std::string cppClassName
;
55 std::optional
<std::string
> doc
;
56 SmallVector
<std::string
> implements
;
57 SmallVector
<std::string
> defines
;
60 struct SerializedAffineMap
{
61 AffineMapAttr affineMapAttr
;
63 AffineMap
affineMap() { return affineMapAttr
.getValue(); }
/// Role of a named op argument; spelled in YAML through the `kind` key.
enum class LinalgOperandDefKind : int {
  InputTensor,
  Scalar,
  OutputTensor,
  IndexAttr,
  UnaryFnAttr,
  BinaryFnAttr,
  TypeFnAttr
};
76 struct LinalgOperandDef
{
78 LinalgOperandDefKind kind
;
79 std::optional
<std::string
> typeVar
;
80 std::optional
<SerializedAffineMap
> shapeMap
;
81 std::optional
<SerializedAffineMap
> indexAttrMap
;
82 std::optional
<SmallVector
<int64_t>> defaultIndices
;
83 std::optional
<std::string
> defaultFn
;
/// Iterator kind of one loop dimension (entries of YAML `iterator_types`).
enum class LinalgIteratorTypeDef : int {
  parallel,
  reduction
};
91 struct LinalgIndexingMapsConfig
{
92 std::optional
<SmallVector
<SerializedAffineMap
>> staticIndexingMaps
;
struct ScalarExpression;

/// Category of a scalar function; serialized as `unary`, `binary` or `type`.
enum class ScalarFnKind { Unary, Binary, Type };

/// A function applied to scalar operands (the YAML `scalar_fn` mapping).
struct ScalarFn {
  /// Function category (YAML `kind`).
  ScalarFnKind kind;
  /// YAML `fn_name`.
  std::optional<std::string> fnName;
  /// YAML `attr_name`.
  std::optional<std::string> attrName;
  /// YAML `type_var`.
  std::optional<std::string> typeVar;
  // ScalarExpression (which itself holds an optional ScalarFn) is still
  // incomplete at this point. std::vector stores its elements out of line and
  // accepts an incomplete element type here, which breaks the cycle.
  std::vector<ScalarExpression> operands;
};
109 struct ScalarExpression
{
110 std::optional
<std::string
> arg
;
111 std::optional
<std::string
> constant
;
112 std::optional
<int64_t> index
;
113 std::optional
<ScalarFn
> scalarFn
;
116 struct ScalarAssign
{
118 ScalarExpression value
;
121 struct LinalgStructuredOpConfig
{
122 SmallVector
<LinalgOperandDef
> args
;
123 LinalgIndexingMapsConfig indexingMaps
;
124 SmallVector
<LinalgIteratorTypeDef
> iteratorTypes
;
125 std::vector
<ScalarAssign
> assignments
;
128 struct LinalgOpConfig
{
129 std::optional
<LinalgOpMetadata
> metadata
;
130 std::optional
<LinalgStructuredOpConfig
> structuredOp
;
135 //===----------------------------------------------------------------------===//
137 //===----------------------------------------------------------------------===//
139 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef
)
140 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap
)
141 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef
)
142 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign
)
143 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression
)
144 LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig
)
149 /// Top-level type containing op metadata and one of a concrete op type.
150 /// Currently, the only defined op type is `structured_op` (maps to
151 /// `LinalgStructuredOpConfig`).
153 struct MappingTraits
<LinalgOpConfig
> {
154 static void mapping(IO
&io
, LinalgOpConfig
&info
) {
155 io
.mapOptional("metadata", info
.metadata
);
156 io
.mapOptional("structured_op", info
.structuredOp
);
160 /// A structured op models (at most) a single contraction by modeling
161 /// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
162 /// outputs, or index attributes.
163 /// - List of indexing maps (see `LinalgIndexingMaps`).
164 /// - Iterator types (see `LinalgIteratorTypeDef`).
165 /// - List of scalar level assignment (see `ScalarAssign`).
167 struct MappingTraits
<LinalgStructuredOpConfig
> {
168 static void mapping(IO
&io
, LinalgStructuredOpConfig
&info
) {
169 io
.mapRequired("args", info
.args
);
170 io
.mapRequired("indexing_maps", info
.indexingMaps
);
171 io
.mapRequired("iterator_types", info
.iteratorTypes
);
172 io
.mapRequired("assignments", info
.assignments
);
176 /// Maps a named tensor, scalar or attribute argument to an operation,
178 /// - `name`: Must be unique within the operation.
179 /// - `usage`: How the argument is used (input, output, attribute, etc).
180 /// - `type_var`: The symbolic type variable that binds to the element or self
181 /// type of the tensor or scalar argument, respectively.
182 /// - `shape_map`: An optional AffineMap from all op symbols to the shape of
183 /// the argument. Only tensor arguments have a `shape_map`. Each shape must
184 /// be normalized over the same list of symbols and have no dimension
186 /// - `index_attr_map`: An optional AffineMap from all op symbols to the
187 /// index attribute symbols. During op creation these symbols are replaced
188 /// by the corresponding `name` index attribue values. Only index attribute
189 /// arguments have an `index_attr_map`.
190 /// - `default_indices`: An optional default initialization for index
191 /// attribute arguments.
192 /// - `default_fn`: An optional default initialization for function attribute
195 struct MappingTraits
<LinalgOperandDef
> {
196 static void mapping(IO
&io
, LinalgOperandDef
&info
) {
197 io
.mapRequired("name", info
.name
);
198 io
.mapRequired("kind", info
.kind
);
199 io
.mapOptional("type_var", info
.typeVar
);
200 io
.mapOptional("shape_map", info
.shapeMap
);
201 io
.mapOptional("index_attr_map", info
.indexAttrMap
);
202 io
.mapOptional("default_indices", info
.defaultIndices
);
203 io
.mapOptional("default_fn", info
.defaultFn
);
207 /// Usage enum for a named argument.
209 struct ScalarEnumerationTraits
<LinalgOperandDefKind
> {
210 static void enumeration(IO
&io
, LinalgOperandDefKind
&value
) {
211 io
.enumCase(value
, "input_tensor", LinalgOperandDefKind::InputTensor
);
212 io
.enumCase(value
, "scalar", LinalgOperandDefKind::Scalar
);
213 io
.enumCase(value
, "output_tensor", LinalgOperandDefKind::OutputTensor
);
214 io
.enumCase(value
, "index_attr", LinalgOperandDefKind::IndexAttr
);
215 io
.enumCase(value
, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr
);
216 io
.enumCase(value
, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr
);
217 io
.enumCase(value
, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr
);
221 /// Iterator type enum.
223 struct ScalarEnumerationTraits
<LinalgIteratorTypeDef
> {
224 static void enumeration(IO
&io
, LinalgIteratorTypeDef
&value
) {
225 io
.enumCase(value
, "parallel", LinalgIteratorTypeDef::parallel
);
226 io
.enumCase(value
, "reduction", LinalgIteratorTypeDef::reduction
);
230 /// Metadata about the op (name, C++ name, and documentation).
232 struct MappingTraits
<LinalgOpMetadata
> {
233 static void mapping(IO
&io
, LinalgOpMetadata
&info
) {
234 io
.mapRequired("name", info
.name
);
235 io
.mapRequired("cpp_class_name", info
.cppClassName
);
236 io
.mapOptional("doc", info
.doc
);
237 io
.mapOptional("implements", info
.implements
);
238 io
.mapOptional("defines", info
.defines
);
242 /// How the ops indexing maps are produced. Must be one of:
243 /// - static_indexing_maps: A static list of AffineMaps, possibly with
244 /// some symbols that bind to attributes of the op. Each indexing map must
245 /// be normalized over the same list of dimensions, and its symbols must
246 /// match the symbols for argument shapes.
248 struct MappingTraits
<LinalgIndexingMapsConfig
> {
249 static void mapping(IO
&io
, LinalgIndexingMapsConfig
&info
) {
250 io
.mapOptional("static_indexing_maps", info
.staticIndexingMaps
);
254 /// Models an assignment to a named output.
255 /// - The `arg` name must match a named output.
256 /// - The `value` is a scalar expression for computing the value to
257 /// assign (see `ScalarExpression`).
259 struct MappingTraits
<ScalarAssign
> {
260 static void mapping(IO
&io
, ScalarAssign
&info
) {
261 io
.mapRequired("arg", info
.arg
);
262 io
.mapRequired("value", info
.value
);
266 /// A scalar expression (RHS of an assignment). Must be one of:
267 /// - `scalar_arg`: An operation argument.
268 /// - `scalar_const`: A constant definition.
269 /// - `scalar_index`: An iteration index.
270 /// - `scalar_fn`: A named function (see `ScalarFn`).
272 struct MappingTraits
<ScalarExpression
> {
273 static void mapping(IO
&io
, ScalarExpression
&info
) {
274 io
.mapOptional("scalar_arg", info
.arg
);
275 io
.mapOptional("scalar_const", info
.constant
);
276 io
.mapOptional("scalar_index", info
.index
);
277 io
.mapOptional("scalar_fn", info
.scalarFn
);
281 /// Scalar function kind enum.
283 struct ScalarEnumerationTraits
<ScalarFnKind
> {
284 static void enumeration(IO
&io
, ScalarFnKind
&value
) {
285 io
.enumCase(value
, "unary", ScalarFnKind::Unary
);
286 io
.enumCase(value
, "binary", ScalarFnKind::Binary
);
287 io
.enumCase(value
, "type", ScalarFnKind::Type
);
291 /// A scalar expression that evaluates a named function.
292 /// Functions are generally "math" level and type polymorphic. Builtin
293 /// functions include:
294 /// - `add(lhs, rhs)`
295 /// - `mul(lhs, rhs)`
297 struct MappingTraits
<ScalarFn
> {
298 static void mapping(IO
&io
, ScalarFn
&info
) {
299 io
.mapRequired("kind", info
.kind
);
300 io
.mapOptional("fn_name", info
.fnName
);
301 io
.mapOptional("attr_name", info
.attrName
);
302 io
.mapOptional("type_var", info
.typeVar
);
303 io
.mapRequired("operands", info
.operands
);
307 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
310 struct ScalarTraits
<SerializedAffineMap
> {
311 static void output(const SerializedAffineMap
&value
, void *rawYamlContext
,
313 assert(value
.affineMapAttr
);
314 value
.affineMapAttr
.print(out
);
316 static StringRef
input(StringRef scalar
, void *rawYamlContext
,
317 SerializedAffineMap
&value
) {
318 assert(rawYamlContext
);
319 auto *yamlContext
= static_cast<LinalgYAMLContext
*>(rawYamlContext
);
320 if (auto attr
= dyn_cast_or_null
<AffineMapAttr
>(
321 mlir::parseAttribute(scalar
, yamlContext
->mlirContext
)))
322 value
.affineMapAttr
= attr
;
323 else if (!value
.affineMapAttr
|| !isa
<AffineMapAttr
>(value
.affineMapAttr
))
324 return "could not parse as an affine map attribute";
327 static QuotingType
mustQuote(StringRef
) { return QuotingType::None
; }
335 //===----------------------------------------------------------------------===//
336 // Generation utilities
337 //===----------------------------------------------------------------------===//
339 class GenerationContext
{
341 GenerationContext(MLIRContext
*context
, raw_ostream
*odsOut
,
342 raw_ostream
*defnOut
)
343 : context(context
), loc(UnknownLoc::get(context
)), odsOut(odsOut
),
346 MLIRContext
*getContext() { return context
; }
348 void setLoc(Location loc
) { this->loc
= loc
; }
349 Location
getLoc() { return loc
; }
351 bool shouldGenerateOds() { return odsOut
; }
352 bool shouldGenerateDefns() { return defnOut
; }
354 raw_ostream
&odss() {
355 assert(odsOut
&& "ODS stream not defined");
359 raw_ostream
&defns() {
360 assert(defnOut
&& "Definition stream not defined");
365 MLIRContext
*context
;
368 raw_ostream
*defnOut
;
373 static std::string
generateCppExpression(SerializedAffineMap self
,
374 StringRef contextName
) {
375 std::string printedStr
;
376 llvm::raw_string_ostream
printedSs(printedStr
);
377 self
.affineMapAttr
.print(printedSs
);
380 static const char exprFormat
[] =
381 R
"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT";
382 return llvm::formatv(exprFormat
, printedStr
, contextName
);
385 template <typename Container
>
386 static std::string
interleaveToString(Container
&container
,
387 StringRef separator
) {
389 llvm::raw_string_ostream
ss(result
);
390 llvm::interleave(container
, ss
, separator
);
395 static std::optional
<int>
396 findTensorDefArgIndex(StringRef name
, SmallVectorImpl
<LinalgOperandDef
> &args
) {
397 for (const auto &it
: llvm::enumerate(args
)) {
398 if (it
.value().name
== name
)
404 // Try to map the TypeVar to a predefined or an argument type.
405 static std::optional
<std::string
>
406 findTypeValue(StringRef typeVar
, SmallVectorImpl
<LinalgOperandDef
> &args
) {
407 // Handle all predefined types.
408 if (typeVar
== "I32")
409 return std::string("helper.getIntegerType(32)");
410 if (typeVar
== "I64")
411 return std::string("helper.getIntegerType(64)");
412 if (typeVar
== "F32")
413 return std::string("helper.getFloat32Type()");
414 if (typeVar
== "F64")
415 return std::string("helper.getFloat64Type()");
417 // Search all argument types.
418 for (const auto &it
: llvm::enumerate(args
)) {
419 if (it
.value().kind
!= LinalgOperandDefKind::InputTensor
&&
420 it
.value().kind
!= LinalgOperandDefKind::Scalar
&&
421 it
.value().kind
!= LinalgOperandDefKind::OutputTensor
)
423 if (*it
.value().typeVar
== typeVar
)
424 return llvm::formatv("block.getArgument({0}).getType()", it
.index())
431 static ScalarAssign
*findAssignment(StringRef name
,
432 std::vector
<ScalarAssign
> &assignments
) {
433 for (auto &assign
: assignments
) {
434 if (assign
.arg
== name
)
440 // Return true if the operand is a function attribute.
441 static bool isFunctionAttribute(LinalgOperandDefKind kind
) {
442 return kind
== LinalgOperandDefKind::UnaryFnAttr
||
443 kind
== LinalgOperandDefKind::BinaryFnAttr
||
444 kind
== LinalgOperandDefKind::TypeFnAttr
;
447 // Return true if the operand is an attribute.
448 static bool isAttribute(LinalgOperandDefKind kind
) {
449 return kind
== LinalgOperandDefKind::IndexAttr
|| isFunctionAttribute(kind
);
452 // Get the enum name for the given operand kind.
453 std::string
convertOperandKindToEnumName(LinalgOperandDefKind kind
) {
455 case LinalgOperandDefKind::UnaryFnAttr
:
456 return std::string("UnaryFn");
457 case LinalgOperandDefKind::BinaryFnAttr
:
458 return std::string("BinaryFn");
459 case LinalgOperandDefKind::TypeFnAttr
:
460 return std::string("TypeFn");
464 llvm_unreachable("unsupported function attribute kind");
467 // Get the enum name for the given function kind.
468 std::string
convertFunctionKindToEnumName(ScalarFnKind kind
) {
470 case ScalarFnKind::Unary
:
471 return std::string("UnaryFn");
472 case ScalarFnKind::Binary
:
473 return std::string("BinaryFn");
474 case ScalarFnKind::Type
:
475 return std::string("TypeFn");
477 llvm_unreachable("unsupported function kind");
480 //===----------------------------------------------------------------------===//
482 //===----------------------------------------------------------------------===//
484 // A single line banner format. Parameters:
485 // {0}: Single line comment
486 static const char bannerFormat
[] = R
"FMT(
487 //===----------------------------------------------------------------------===//
489 //===----------------------------------------------------------------------===//
492 //===----------------------------------------------------------------------===//
493 // Named generic op generation.
494 // These ops map at most a single contraction that complies with the limitations
495 // of a linalg.generic.
496 //===----------------------------------------------------------------------===//
498 // Template for Linalg named ops' ODS definitions. Parameters:
499 // {0}: ODS/C++ op name
500 // {1}: assembly op mnemonic
501 // {2}: op interface list
502 // {3}: documentation (summary + description)
503 // {4}: op attribute list
504 // {5}: builder methods taking standalone attribute parameters
// {6}: additional method definitions
506 // {7}: additional methods for attributes used by indexing maps
507 static const char structuredOpOdsHeaderFormat
[] = R
"FMT(
508 //===----------------------------------------------------------------------===//
509 // Op definition for {0}
510 //===----------------------------------------------------------------------===//
512 def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
513 /*extraInterfaces=*/[{2}])> {
516 Variadic<AnyType>:$inputs,
517 Variadic<AnyShaped>:$outputs{4}
519 let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
520 let regions = (region AnyRegion:$region);
522 let skipDefaultBuilders = 1;
525 (ins "ValueRange
":$inputs, "ValueRange
":$outputs,
526 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
528 buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
529 attributes, {0}::getRegionBuilder());
532 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$inputs,
533 "ValueRange
":$outputs,
534 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
536 buildStructuredOp($_builder, $_state, resultTensorTypes,
537 inputs, outputs, attributes, {0}::getRegionBuilder());
540 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$operands,
541 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
543 $_state.addOperands(operands);
544 $_state.addAttributes(attributes);
545 $_state.addTypes(resultTensorTypes);
546 (void)$_state.addRegion();
550 let hasCustomAssemblyFormat = 1;
554 let extraClassDeclaration = structuredOpsBaseDecls # [{{
556 SmallVector<utils::IteratorType> getIteratorTypesArray();
557 ArrayAttr getIndexingMaps();
558 static void regionBuilder(ImplicitLocOpBuilder &b,
559 Block &block, ArrayRef<NamedAttribute> attrs);
560 static std::function<void(ImplicitLocOpBuilder &,
561 Block &, ArrayRef<NamedAttribute>)>
562 getRegionBuilder() {{
563 return regionBuilder;
566 ::mlir::MutableOperandRange getDpsInitsMutable() {{
567 return getOutputsMutable();
571 static unsigned getNumRegionArgs();
572 std::string getLibraryCallName();
578 // Builder method taking attribute parameters. Parameters:
580 // {1}: Comma interleaved attribute parameters
581 // {2}: Attribute initialization
582 static const char structuredOpBuilderFormat
[] = R
"FMT(
584 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$inputs,
585 "ValueRange
":$outputs, {1},
586 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
589 buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
590 attributes, {0}::getRegionBuilder());
594 // The getIteratorTypesArray() method for structured ops. Parameters:
596 // {1}: Comma interleaved iterator type names.
597 static const char structuredOpIteratorTypesFormat
[] =
599 SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
600 return SmallVector<utils::IteratorType>{{ {1} };
604 // The getIteratorTypesArray() method for rank polymorphic structured ops.
607 static const char rankPolyStructuredOpIteratorTypesFormat
[] =
609 SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
610 int64_t rank = getRank(getDpsInitOperand(0));
611 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
615 // The indexing_maps() method for structured ops. Parameters:
617 // {1}: Comma-separated list of dimension variable names.
619 static const char structuredOpIndexingMapsFormat
[] = R
"FMT(
620 ArrayAttr {0}::getIndexingMaps() {{
621 static const char memoizeAttr[] = "linalg
.memoized_indexing_maps
";
622 ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
626 MLIRContext *context = getContext();
627 auto symbolBindings = getSymbolBindings(*this);
628 SmallVector<AffineMap> maps;
630 cached = Builder(context).getAffineMapArrayAttr(maps);
631 getOperation()->setAttr(memoizeAttr, cached);
636 // The indexing_maps() method for rank polymorphic structured ops. Parameters:
638 static const char rankPolyStructuredOpIndexingMapsFormat
[] = R
"FMT(
639 ArrayAttr {0}::getIndexingMaps() {{
640 MLIRContext *context = getContext();
641 AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
642 AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
643 getNumParallelLoops(), context);
644 SmallVector<AffineMap> indexingMaps;
645 for (OpOperand &opOperand : getOperation()->getOpOperands())
646 indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
647 return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
651 // Implementations of fold and getEffects.
654 const char structuredOpFoldersFormat
[] = R
"FMT(
655 LogicalResult {0}::fold(FoldAdaptor,
656 SmallVectorImpl<OpFoldResult> &) {{
657 return memref::foldMemRefCast(*this);
659 void {0}::getEffects(SmallVectorImpl<
660 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
661 if (hasPureTensorSemantics()) return;
662 getGenericEffectsImpl(effects,
663 getOperation()->getResults(), getDpsInputs(), getDpsInits());
667 // Implementation of parse/print.
670 static const char structuredOpParserFormat
[] = R
"FMT(
671 ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
672 return ::parseNamedStructuredOp(parser, result,
673 {0}::getNumRegionArgs(), {0}::getRegionBuilder());
675 void {0}::print(OpAsmPrinter &p) {{
676 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
680 static LogicalResult
generateNamedGenericOpOds(LinalgOpConfig
&opConfig
,
681 GenerationContext
&genContext
) {
682 if (!genContext
.shouldGenerateOds())
685 raw_ostream
&os
= genContext
.odss();
687 std::string interfaceNameList
;
688 std::string attrList
;
689 std::string attrMethods
;
690 std::string attrBuilder
;
693 if (opConfig
.metadata
->doc
) {
694 static const char structuredOpDocFmt
[] = R
"FMT(
695 let summary = [{{{0}}];
696 let description = [{{{1}}];
698 StringRef summary
, description
;
699 std::tie(summary
, description
) =
700 StringRef(*opConfig
.metadata
->doc
).trim().split("\n\n");
702 doc
= llvm::formatv(structuredOpDocFmt
, summary
.trim(), description
.trim());
705 interfaceNameList
= interleaveToString(opConfig
.metadata
->implements
, ", ");
707 std::string definitionList
;
708 for (const std::string
&definition
: opConfig
.metadata
->defines
) {
709 static const char definitionFmt
[] = "let {0} = 1;\n";
710 definitionList
.append(llvm::formatv(definitionFmt
, definition
));
713 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
714 return isAttribute(arg
.kind
);
716 SmallVector
<std::string
> attrDefs
;
717 SmallVector
<std::string
> attrParams
;
718 SmallVector
<std::string
> attrStmts
;
719 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
720 static const char paramFmt
[] = "\"Attribute\":${0}";
721 static const char stmtFmt
[] = "$_state.addAttribute(\"{0}\", {0});";
722 // Add the type conversion attributes to the op definition and builders.
723 if (isFunctionAttribute(arg
.kind
)) {
724 assert(arg
.defaultFn
);
725 std::string enumName
= convertOperandKindToEnumName(arg
.kind
);
726 static const char typeFmt
[] = "{0}::{1}";
727 static const char defFmt
[] =
728 "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}";
729 attrDefs
.push_back(llvm::formatv(
730 defFmt
, llvm::formatv("{0}Attr", enumName
),
731 llvm::formatv(typeFmt
, enumName
, arg
.defaultFn
), arg
.name
));
732 attrParams
.push_back(llvm::formatv(paramFmt
, arg
.name
));
733 attrStmts
.push_back(llvm::formatv(stmtFmt
, arg
.name
));
735 // Add the index attributes to the op definition and builders.
736 if (arg
.kind
== LinalgOperandDefKind::IndexAttr
) {
737 assert(arg
.indexAttrMap
.has_value());
738 assert(arg
.defaultIndices
.has_value());
739 size_t size
= arg
.indexAttrMap
->affineMap().getNumResults();
740 assert(arg
.defaultIndices
->size() == size
);
741 static const char typeFmt
[] = "RankedI64ElementsAttr<[{0}]>";
742 static const char defFmt
[] =
743 "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}";
744 std::string defaultVals
;
745 llvm::raw_string_ostream
ss(defaultVals
);
747 *arg
.defaultIndices
, ss
,
748 [&](int64_t val
) { ss
<< "static_cast<int64_t>(" << val
<< ")"; },
750 attrDefs
.push_back(llvm::formatv(defFmt
, llvm::formatv(typeFmt
, size
),
751 ss
.str(), arg
.name
));
752 attrParams
.push_back(llvm::formatv(paramFmt
, arg
.name
));
753 attrStmts
.push_back(llvm::formatv(stmtFmt
, arg
.name
));
756 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
757 return arg
.kind
== LinalgOperandDefKind::IndexAttr
;
760 bool hasDynamicIndexingMaps();
761 LogicalResult verifyIndexingMapRequiredAttributes();
764 attrList
= ",\n" + llvm::join(attrDefs
, ",\n");
765 attrBuilder
= llvm::formatv(
766 structuredOpBuilderFormat
, opConfig
.metadata
->cppClassName
,
767 llvm::join(attrParams
, ", "), llvm::join(attrStmts
, "\n"));
770 os
<< llvm::formatv(structuredOpOdsHeaderFormat
,
771 opConfig
.metadata
->cppClassName
, opConfig
.metadata
->name
,
772 interfaceNameList
, doc
, attrList
, attrBuilder
,
773 definitionList
, attrMethods
);
779 generateNamedGenericOpDefns(LinalgOpConfig
&opConfig
,
780 GenerationContext
&genContext
) {
781 if (!genContext
.shouldGenerateDefns())
784 raw_ostream
&os
= genContext
.defns();
785 StringRef className
= opConfig
.metadata
->cppClassName
;
787 // Implementation banner.
788 std::string bannerComment
= llvm::formatv("Implementation of {0}", className
);
789 os
<< llvm::formatv(bannerFormat
, bannerComment
);
791 // Compute the number of scalar and tensor arguments.
793 llvm::count_if(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
794 return arg
.kind
== LinalgOperandDefKind::InputTensor
||
795 arg
.kind
== LinalgOperandDefKind::Scalar
||
796 arg
.kind
== LinalgOperandDefKind::OutputTensor
;
799 // An operation that accesses only scalars and scalar/rank zero tensors is
800 // rank polymorhpic. We implement rank polymorphism by generating different
801 // indexing maps and iterators that match the rank of the first output tensor.
802 // An operation is rank polymorphic if the iteration domain has rank zero.
803 bool isRankPolymorphic
= opConfig
.structuredOp
->iteratorTypes
.empty();
805 // Generate the iterator_types() method.
806 if (!isRankPolymorphic
) {
807 std::string iteratorsStr
;
808 llvm::raw_string_ostream
ss(iteratorsStr
);
809 llvm::interleaveComma(opConfig
.structuredOp
->iteratorTypes
, ss
,
810 [&](LinalgIteratorTypeDef it
) {
812 case LinalgIteratorTypeDef::parallel
:
813 ss
<< "utils::IteratorType::parallel";
815 case LinalgIteratorTypeDef::reduction
:
816 ss
<< "utils::IteratorType::reduction";
821 os
<< llvm::formatv(structuredOpIteratorTypesFormat
, className
,
824 os
<< llvm::formatv(rankPolyStructuredOpIteratorTypesFormat
, className
);
827 // Generating the getIndexingMaps() method.
828 if (auto &staticMaps
=
829 opConfig
.structuredOp
->indexingMaps
.staticIndexingMaps
) {
830 if (staticMaps
->empty())
831 return emitError(genContext
.getLoc()) << "op has no indexing maps";
832 if (!isRankPolymorphic
) {
833 AffineMap firstMap
= staticMaps
->front().affineMap();
837 // For each symbol, generate a declaration for it, either with an
838 // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
840 // TODO: Possibly lift into a top-level method.
841 static const char structuredOpSymbolBindingsFormat
[] = R
"FMT(
842 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
843 MLIRContext *context = self.getContext();
844 SmallVector<AffineExpr> exprs;
850 unsigned symbolCount
= firstMap
.getNumSymbols();
851 SmallVector
<std::string
> symbolBindings
;
852 for (unsigned i
= 0; i
< symbolCount
; ++i
) {
853 symbolBindings
.push_back(llvm::formatv(
854 " exprs.push_back(getAffineSymbolExpr({0}, context));", i
));
857 // Access an index attribute. Parameters:
858 // {0}: Attribute name
859 // {1}: Symbol position
860 // {2}: Attribute index
861 static const char structuredOpAccessAttrFormat
[] = R
"FMT(
862 int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}];
863 exprs.push_back(getAffineConstantExpr(cst{1}, context));
865 // Update all symbol bindings mapped to an attribute.
866 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
867 if (arg
.kind
!= LinalgOperandDefKind::IndexAttr
)
869 assert(arg
.indexAttrMap
);
870 for (auto [idx
, result
] :
871 llvm::enumerate(arg
.indexAttrMap
->affineMap().getResults())) {
872 if (auto symbol
= dyn_cast
<AffineSymbolExpr
>(result
)) {
873 std::string argName
= arg
.name
;
874 argName
[0] = toupper(argName
[0]);
875 symbolBindings
[symbol
.getPosition()] =
876 llvm::formatv(structuredOpAccessAttrFormat
, argName
,
877 symbol
.getPosition(), idx
);
882 std::string symbolBindingsStr
;
883 llvm::raw_string_ostream
symbolBindingsSs(symbolBindingsStr
);
884 llvm::interleave(symbolBindings
, symbolBindingsSs
, "\n");
885 symbolBindingsSs
.flush();
887 os
<< llvm::formatv(structuredOpSymbolBindingsFormat
, className
,
893 unsigned dimCount
= firstMap
.getNumDims();
895 // Generate a comma-separated list of dim identifiers to be passed to
896 // bindDims, ensuring tht AffineExpr identifiers are bound in the right
897 // order to the proper AffineDimExpr.
898 // This results in vars in scope like: d0, d1, d2...
899 SmallVector
<unsigned> dimIndices
;
900 for (unsigned i
= 0; i
< dimCount
; ++i
)
901 dimIndices
.push_back(i
);
902 std::string dimIdentsStr
;
903 llvm::raw_string_ostream
dimIdentsSs(dimIdentsStr
);
904 llvm::interleaveComma(dimIndices
, dimIdentsSs
,
905 [&](unsigned i
) { dimIdentsSs
<< "d" << i
; });
908 // Statements to add and simplify each affine map.
909 SmallVector
<std::string
> stmts
;
910 for (auto &indexingMap
: *staticMaps
) {
911 // TODO: Assert that dim and symbol count match the first.
913 llvm::formatv("maps.push_back({0});",
914 generateCppExpression(indexingMap
, "context")));
915 stmts
.push_back(llvm::formatv(
917 "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
918 "symbolBindings, {0}, 0));",
922 // TODO: This needs to be memoized and/or converted to non-parser based
923 // C++ codegen prior to real use.
924 os
<< llvm::formatv(structuredOpIndexingMapsFormat
, className
,
925 dimIdentsStr
, interleaveToString(stmts
, "\n "));
928 os
<< llvm::formatv(rankPolyStructuredOpIndexingMapsFormat
, className
);
931 return emitError(genContext
.getLoc())
932 << "generating code for non static indexing maps not currently "
936 // getNumRegionArgs()
938 // Generates a getNumRegionArgs() method. Parameters:
940 // {1}: Number of region args
941 static const char structuredOpGetNumRegionArgsFormat
[] = R
"FMT(
942 unsigned {0}::getNumRegionArgs() {{ return {1}; }
944 os
<< llvm::formatv(structuredOpGetNumRegionArgsFormat
, className
,
948 // getLibraryCallName()
950 // Generates a getLibraryCallName method. Parameters:
952 static const char structuredOpGetLibraryCallFormat
[] = R
"FMT(
953 std::string {0}::getLibraryCallName() {{
954 return generateLibraryCallName(getOperation());
957 os
<< llvm::formatv(structuredOpGetLibraryCallFormat
, className
);
960 // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
961 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
962 return arg
.kind
== LinalgOperandDefKind::IndexAttr
;
964 std::vector
<std::string
> attrVerifications
;
965 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
966 if (arg
.kind
!= LinalgOperandDefKind::IndexAttr
)
968 assert(arg
.indexAttrMap
);
969 // Verify index attribute. Paramters:
970 // {0}: Attribute name
971 // {1}: Attribute size
972 static const char attrFmt
[] = R
"FMT(
973 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
974 if (!attr.getType().getElementType().isInteger(64))
975 return op->emitError("incorrect element type
for index attribute
'{0}'");
976 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
977 return op->emitError("incorrect shape
for index attribute
'{0}'");
980 attrVerifications
.push_back(llvm::formatv(
981 attrFmt
, arg
.name
, arg
.indexAttrMap
->affineMap().getNumResults()));
984 // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
986 // {1}: Attribute verification
987 static const char structuredOpVerifyIndexingMapRequiredAttributes
[] = R
"FMT(
988 bool {0}::hasDynamicIndexingMaps() {{ return true; }
989 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
990 Operation *op = getOperation();
995 os
<< llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes
,
996 className
, llvm::join(attrVerifications
, "\n"));
1001 // Generates a regionBuilder method. Parameters.
1003 // {1}: Number of args
1006 static const char structuredOpRegionBuilderFormat
[] = R
"FMT(
1007 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1008 Block &block, ArrayRef<NamedAttribute> attrs) {{
1009 assert({1} > 0 && block.getNumArguments() == {1} &&
1010 "{0} regionBuilder expects
{1} (>=0) args
");
1011 RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
1012 SmallVector<Value> yields;
1015 helper.yieldOutputs(yields);
1018 auto &args
= opConfig
.structuredOp
->args
;
1019 auto &assignments
= opConfig
.structuredOp
->assignments
;
1020 size_t generatedAssignmentCount
= 0;
1021 int localCounter
= 0;
1022 SmallVector
<std::string
> attrs
;
1023 SmallVector
<std::string
> stmts
;
1024 for (LinalgOperandDef
&arg
: args
) {
1025 if (!isFunctionAttribute(arg
.kind
))
1027 // Obtain the type function attribute values. Parameters.
1029 // {1}: attribute name
1030 // {2}: default type function name
1031 static const char attrDef
[] = R
"FMT(
1032 {0} {1}Val = {0}::{2};
1033 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1034 return attr.getName() == "{1}"; });
1035 if ({1}Iter != attrs.end()) {{
1036 if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1037 {1}Val = attr.getValue();
1040 std::string enumName
= convertOperandKindToEnumName(arg
.kind
);
1042 llvm::formatv(attrDef
, enumName
, arg
.name
, arg
.defaultFn
));
1044 for (LinalgOperandDef
&arg
: args
) {
1045 if (arg
.kind
!= LinalgOperandDefKind::OutputTensor
)
1048 // Find the assignment that correlates with the argument.
1049 ScalarAssign
*assignment
= findAssignment(arg
.name
, assignments
);
1051 return emitError(genContext
.getLoc())
1052 << "no assignment found for output argument " << arg
.name
;
1053 ++generatedAssignmentCount
;
1055 // Recursively generate the expression.
1056 std::function
<std::optional
<std::string
>(ScalarExpression
&)>
1057 generateExpression
=
1058 [&](ScalarExpression
&expression
) -> std::optional
<std::string
> {
1059 if (expression
.arg
) {
1060 // Argument reference.
1061 std::optional
<int> argIndex
=
1062 findTensorDefArgIndex(*expression
.arg
, args
);
1064 emitError(genContext
.getLoc())
1065 << "scalar argument not defined on the op: " << *expression
.arg
;
1066 return std::nullopt
;
1069 llvm::formatv("block.getArgument({0})", *argIndex
));
1071 if (expression
.constant
) {
1072 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1074 llvm::formatv(R
"FMT(Value {0} = helper.constant("{1}");)FMT",
1075 cppIdent
, expression
.constant
));
1078 if (expression
.index
) {
1079 // Access an iteration index.
1080 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1081 stmts
.push_back(llvm::formatv("Value {0} = helper.index({1});",
1082 cppIdent
, *expression
.index
));
1085 if (expression
.scalarFn
) {
1086 std::string enumName
=
1087 convertFunctionKindToEnumName(expression
.scalarFn
->kind
);
1089 // Get the function or attribute name.
1090 assert(expression
.scalarFn
->fnName
|| expression
.scalarFn
->attrName
);
1091 std::string funcType
;
1092 if (expression
.scalarFn
->fnName
) {
1093 funcType
= llvm::formatv("{0}::{1}", enumName
,
1094 *expression
.scalarFn
->fnName
);
1096 if (expression
.scalarFn
->attrName
) {
1097 if (llvm::none_of(args
, [&](LinalgOperandDef
&arg
) {
1098 return isFunctionAttribute(arg
.kind
) &&
1099 arg
.name
== *expression
.scalarFn
->attrName
;
1101 emitError(genContext
.getLoc()) << "missing function attribute "
1102 << *expression
.scalarFn
->attrName
;
1104 funcType
= llvm::formatv("{0}Val", *expression
.scalarFn
->attrName
);
1106 assert(!funcType
.empty());
1108 // Add the optional type parameter to the operands.
1109 SmallVector
<std::string
> operandCppValues
;
1110 if (expression
.scalarFn
->kind
== ScalarFnKind::Type
) {
1111 assert(expression
.scalarFn
->typeVar
.has_value());
1112 std::optional
<std::string
> typeCppValue
=
1113 findTypeValue(*expression
.scalarFn
->typeVar
, args
);
1114 if (!typeCppValue
) {
1115 emitError(genContext
.getLoc())
1116 << "type variable " << *expression
.scalarFn
->typeVar
1117 << ", used in a type conversion, must map to a predefined or "
1118 << "an argument type but it does not";
1119 return std::nullopt
;
1121 operandCppValues
.push_back(*typeCppValue
);
1124 // Collect the scalar operands.
1125 for (ScalarExpression
&operand
: expression
.scalarFn
->operands
) {
1126 auto operandCppValue
= generateExpression(operand
);
1127 if (!operandCppValue
)
1128 return std::nullopt
;
1129 operandCppValues
.push_back(*operandCppValue
);
1132 // Call the function builder.
1133 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1134 stmts
.push_back(llvm::formatv(
1135 "Value {0} = helper.build{1}({2}, {3});", cppIdent
, enumName
,
1136 funcType
, interleaveToString(operandCppValues
, ", ")));
1139 emitError(genContext
.getLoc()) << "unknown ScalarExpression type";
1140 return std::nullopt
;
1142 std::optional
<std::string
> cppValue
=
1143 generateExpression(assignment
->value
);
1146 stmts
.push_back(llvm::formatv("yields.push_back({0});", *cppValue
));
1149 if (generatedAssignmentCount
!= assignments
.size())
1150 return emitError(genContext
.getLoc())
1151 << "mismatched number of assignments vs output arguments";
1153 os
<< llvm::formatv(structuredOpRegionBuilderFormat
, className
, numOfArgs
,
1154 interleaveToString(attrs
, "\n "),
1155 interleaveToString(stmts
, "\n "));
1158 // Parser and printer.
1159 os
<< llvm::formatv(structuredOpParserFormat
, className
);
1161 // Canonicalizers and folders.
1162 os
<< llvm::formatv(structuredOpFoldersFormat
, className
);
1167 static LogicalResult
generateOp(LinalgOpConfig
&opConfig
,
1168 GenerationContext
&genContext
) {
1169 // Switch on op type being generated.
1170 if (opConfig
.structuredOp
) {
1172 succeeded(generateNamedGenericOpOds(opConfig
, genContext
)) &&
1173 succeeded(generateNamedGenericOpDefns(opConfig
, genContext
)));
1175 return emitError(genContext
.getLoc()) << "unsupported operation type";
1178 //===----------------------------------------------------------------------===//
1179 // Command line options and main
1180 //===----------------------------------------------------------------------===//
1182 static llvm::cl::opt
<std::string
>
1183 inputFilename(llvm::cl::Positional
, llvm::cl::desc("<input file>"),
1184 llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
1186 static llvm::cl::opt
<std::string
>
1187 outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1188 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1190 static llvm::cl::opt
<std::string
>
1191 outputCppImplFilename("o-impl",
1192 llvm::cl::desc("C++ implementation file name"),
1193 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1195 int main(int argc
, char **argv
) {
1196 llvm::cl::ParseCommandLineOptions(argc
, argv
, "Linalg ODS Gen from YAML");
1198 // Set up the input file.
1199 std::string errorMessage
;
1200 std::unique_ptr
<llvm::MemoryBuffer
> file
=
1201 mlir::openInputFile(inputFilename
, &errorMessage
);
1203 llvm::errs() << errorMessage
<< "\n";
1207 MLIRContext mlirContext
;
1208 LinalgYAMLContext yamlContext
{&mlirContext
};
1210 std::vector
<LinalgOpConfig
> opConfigs
;
1213 Input
yin(file
->getBuffer(), &yamlContext
);
1219 // Open output files.
1220 std::unique_ptr
<llvm::ToolOutputFile
> outputOdsDecl
;
1221 if (!outputOdsDeclFilename
.empty()) {
1222 outputOdsDecl
= openOutputFile(outputOdsDeclFilename
, &errorMessage
);
1223 if (!outputOdsDecl
) {
1224 llvm::errs() << errorMessage
<< "\n";
1229 std::unique_ptr
<llvm::ToolOutputFile
> outputCppImpl
;
1230 if (!outputCppImplFilename
.empty()) {
1231 outputCppImpl
= openOutputFile(outputCppImplFilename
, &errorMessage
);
1232 if (!outputCppImpl
) {
1233 llvm::errs() << errorMessage
<< "\n";
1238 if (!outputOdsDecl
&& !outputCppImpl
) {
1239 llvm::errs() << "error: No output files specified\n";
1244 GenerationContext
genContext(&mlirContext
,
1245 outputOdsDecl
? &outputOdsDecl
->os() : nullptr,
1246 outputCppImpl
? &outputCppImpl
->os() : nullptr);
1248 for (auto &opConfig
: opConfigs
) {
1249 if (!opConfig
.metadata
) {
1250 emitError(genContext
.getLoc())
1251 << "missing operation metadata on subsequent op";
1255 genContext
.setLoc(NameLoc::get(
1256 StringAttr::get(&mlirContext
, opConfig
.metadata
->cppClassName
)));
1257 if (failed(generateOp(opConfig
, genContext
))) {
1263 outputOdsDecl
->keep();
1265 outputCppImpl
->keep();