1 //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements an ODS (and C++) generator from a YAML form
10 // derived from the mathematical expression of linalg named ops. Typically a
11 // math oriented DSL will be used to export the essential representation to
12 // this form, and maintaining the SOT at the math level (versus recreating it
13 // in MLIR) is deemed to have systemic value.
15 //===----------------------------------------------------------------------===//
17 #include "mlir/AsmParser/AsmParser.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/ToolOutputFile.h"
28 #include "llvm/Support/YAMLTraits.h"
33 using llvm::yaml::Input
;
34 using llvm::yaml::MappingTraits
;
35 using llvm::yaml::ScalarEnumerationTraits
;
36 using llvm::yaml::ScalarTraits
;
38 #define DEBUG_TYPE "linalg-ods-gen"
40 //===----------------------------------------------------------------------===//
41 // Mapping structs (correspond to data types in the YAML description).
42 // TODO: Since this is a schema/part of the contract, it should be moved to
44 //===----------------------------------------------------------------------===//
48 struct LinalgYAMLContext
{
49 MLIRContext
*mlirContext
;
52 struct LinalgOpMetadata
{
54 std::string cppClassName
;
55 std::optional
<std::string
> doc
;
56 SmallVector
<std::string
> implements
;
57 SmallVector
<std::string
> defines
;
60 struct SerializedAffineMap
{
61 AffineMapAttr affineMapAttr
;
63 AffineMap
affineMap() { return affineMapAttr
.getValue(); }
66 enum class LinalgOperandDefKind
{
76 struct LinalgOperandDef
{
78 LinalgOperandDefKind kind
;
79 std::optional
<std::string
> typeVar
;
80 std::optional
<SerializedAffineMap
> shapeMap
;
81 std::optional
<SerializedAffineMap
> indexAttrMap
;
82 std::optional
<SmallVector
<int64_t>> defaultIndices
;
83 std::optional
<std::string
> defaultFn
;
86 enum class LinalgIteratorTypeDef
{
91 struct LinalgIndexingMapsConfig
{
92 std::optional
<SmallVector
<SerializedAffineMap
>> staticIndexingMaps
;
95 struct ScalarExpression
;
97 enum class ScalarFnKind
{ Unary
, Binary
, Type
};
101 std::optional
<std::string
> fnName
;
102 std::optional
<std::string
> attrName
;
103 std::optional
<std::string
> typeVar
;
104 // NOTE: This must be of arity 1, but to break the self-referential cycle,
105 // we use a heap allocated vector.
106 std::vector
<ScalarExpression
> operands
;
109 struct ScalarExpression
{
110 std::optional
<std::string
> arg
;
111 std::optional
<std::string
> constant
;
112 std::optional
<int64_t> index
;
113 std::optional
<ScalarFn
> scalarFn
;
116 struct ScalarAssign
{
118 ScalarExpression value
;
121 struct LinalgStructuredOpConfig
{
122 SmallVector
<LinalgOperandDef
> args
;
123 LinalgIndexingMapsConfig indexingMaps
;
124 SmallVector
<LinalgIteratorTypeDef
> iteratorTypes
;
125 std::vector
<ScalarAssign
> assignments
;
128 struct LinalgOpConfig
{
129 std::optional
<LinalgOpMetadata
> metadata
;
130 std::optional
<LinalgStructuredOpConfig
> structuredOp
;
135 //===----------------------------------------------------------------------===//
137 //===----------------------------------------------------------------------===//
139 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef
)
140 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap
)
141 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef
)
142 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign
)
143 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression
)
144 LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig
)
149 /// Top-level type containing op metadata and one of the concrete op types.
150 /// Currently, the only defined op type is `structured_op` (maps to
151 /// `LinalgStructuredOpConfig`).
153 struct MappingTraits
<LinalgOpConfig
> {
154 static void mapping(IO
&io
, LinalgOpConfig
&info
) {
155 io
.mapOptional("metadata", info
.metadata
);
156 io
.mapOptional("structured_op", info
.structuredOp
);
160 /// A structured op models (at most) a single contraction by modeling
161 /// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
162 /// outputs, or index attributes.
163 /// - List of indexing maps (see `LinalgIndexingMapsConfig`).
164 /// - Iterator types (see `LinalgIteratorTypeDef`).
165 /// - List of scalar level assignments (see `ScalarAssign`).
167 struct MappingTraits
<LinalgStructuredOpConfig
> {
168 static void mapping(IO
&io
, LinalgStructuredOpConfig
&info
) {
169 io
.mapRequired("args", info
.args
);
170 io
.mapRequired("indexing_maps", info
.indexingMaps
);
171 io
.mapRequired("iterator_types", info
.iteratorTypes
);
172 io
.mapRequired("assignments", info
.assignments
);
176 /// Maps a named tensor, scalar or attribute argument to an operation,
178 /// - `name`: Must be unique within the operation.
179 /// - `kind`: How the argument is used (input, output, attribute, etc.).
180 /// - `type_var`: The symbolic type variable that binds to the element or self
181 /// type of the tensor or scalar argument, respectively.
182 /// - `shape_map`: An optional AffineMap from all op symbols to the shape of
183 /// the argument. Only tensor arguments have a `shape_map`. Each shape must
184 /// be normalized over the same list of symbols and have no dimension
186 /// - `index_attr_map`: An optional AffineMap from all op symbols to the
187 /// index attribute symbols. During op creation these symbols are replaced
188 /// by the corresponding `name` index attribute values. Only index attribute
189 /// arguments have an `index_attr_map`.
190 /// - `default_indices`: An optional default initialization for index
191 /// attribute arguments.
192 /// - `default_fn`: An optional default initialization for function attribute
195 struct MappingTraits
<LinalgOperandDef
> {
196 static void mapping(IO
&io
, LinalgOperandDef
&info
) {
197 io
.mapRequired("name", info
.name
);
198 io
.mapRequired("kind", info
.kind
);
199 io
.mapOptional("type_var", info
.typeVar
);
200 io
.mapOptional("shape_map", info
.shapeMap
);
201 io
.mapOptional("index_attr_map", info
.indexAttrMap
);
202 io
.mapOptional("default_indices", info
.defaultIndices
);
203 io
.mapOptional("default_fn", info
.defaultFn
);
207 /// Usage enum for a named argument.
209 struct ScalarEnumerationTraits
<LinalgOperandDefKind
> {
210 static void enumeration(IO
&io
, LinalgOperandDefKind
&value
) {
211 io
.enumCase(value
, "input_tensor", LinalgOperandDefKind::InputTensor
);
212 io
.enumCase(value
, "scalar", LinalgOperandDefKind::Scalar
);
213 io
.enumCase(value
, "output_tensor", LinalgOperandDefKind::OutputTensor
);
214 io
.enumCase(value
, "index_attr", LinalgOperandDefKind::IndexAttr
);
215 io
.enumCase(value
, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr
);
216 io
.enumCase(value
, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr
);
217 io
.enumCase(value
, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr
);
221 /// Iterator type enum.
223 struct ScalarEnumerationTraits
<LinalgIteratorTypeDef
> {
224 static void enumeration(IO
&io
, LinalgIteratorTypeDef
&value
) {
225 io
.enumCase(value
, "parallel", LinalgIteratorTypeDef::parallel
);
226 io
.enumCase(value
, "reduction", LinalgIteratorTypeDef::reduction
);
230 /// Metadata about the op (name, C++ name, and documentation).
232 struct MappingTraits
<LinalgOpMetadata
> {
233 static void mapping(IO
&io
, LinalgOpMetadata
&info
) {
234 io
.mapRequired("name", info
.name
);
235 io
.mapRequired("cpp_class_name", info
.cppClassName
);
236 io
.mapOptional("doc", info
.doc
);
237 io
.mapOptional("implements", info
.implements
);
238 io
.mapOptional("defines", info
.defines
);
242 /// How the op's indexing maps are produced. Must be one of:
243 /// - static_indexing_maps: A static list of AffineMaps, possibly with
244 /// some symbols that bind to attributes of the op. Each indexing map must
245 /// be normalized over the same list of dimensions, and its symbols must
246 /// match the symbols for argument shapes.
248 struct MappingTraits
<LinalgIndexingMapsConfig
> {
249 static void mapping(IO
&io
, LinalgIndexingMapsConfig
&info
) {
250 io
.mapOptional("static_indexing_maps", info
.staticIndexingMaps
);
254 /// Models an assignment to a named output.
255 /// - The `arg` name must match a named output.
256 /// - The `value` is a scalar expression for computing the value to
257 /// assign (see `ScalarExpression`).
259 struct MappingTraits
<ScalarAssign
> {
260 static void mapping(IO
&io
, ScalarAssign
&info
) {
261 io
.mapRequired("arg", info
.arg
);
262 io
.mapRequired("value", info
.value
);
266 /// A scalar expression (RHS of an assignment). Must be one of:
267 /// - `scalar_arg`: An operation argument.
268 /// - `scalar_const`: A constant definition.
269 /// - `scalar_index`: An iteration index.
270 /// - `scalar_fn`: A named function (see `ScalarFn`).
272 struct MappingTraits
<ScalarExpression
> {
273 static void mapping(IO
&io
, ScalarExpression
&info
) {
274 io
.mapOptional("scalar_arg", info
.arg
);
275 io
.mapOptional("scalar_const", info
.constant
);
276 io
.mapOptional("scalar_index", info
.index
);
277 io
.mapOptional("scalar_fn", info
.scalarFn
);
281 /// Scalar function kind enum.
283 struct ScalarEnumerationTraits
<ScalarFnKind
> {
284 static void enumeration(IO
&io
, ScalarFnKind
&value
) {
285 io
.enumCase(value
, "unary", ScalarFnKind::Unary
);
286 io
.enumCase(value
, "binary", ScalarFnKind::Binary
);
287 io
.enumCase(value
, "type", ScalarFnKind::Type
);
291 /// A scalar expression that evaluates a named function.
292 /// Functions are generally "math" level and type polymorphic. Builtin
293 /// functions include:
294 /// - `add(lhs, rhs)`
295 /// - `mul(lhs, rhs)`
297 struct MappingTraits
<ScalarFn
> {
298 static void mapping(IO
&io
, ScalarFn
&info
) {
299 io
.mapRequired("kind", info
.kind
);
300 io
.mapOptional("fn_name", info
.fnName
);
301 io
.mapOptional("attr_name", info
.attrName
);
302 io
.mapOptional("type_var", info
.typeVar
);
303 io
.mapRequired("operands", info
.operands
);
307 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
310 struct ScalarTraits
<SerializedAffineMap
> {
311 static void output(const SerializedAffineMap
&value
, void *rawYamlContext
,
313 assert(value
.affineMapAttr
);
314 value
.affineMapAttr
.print(out
);
316 static StringRef
input(StringRef scalar
, void *rawYamlContext
,
317 SerializedAffineMap
&value
) {
318 assert(rawYamlContext
);
319 auto *yamlContext
= static_cast<LinalgYAMLContext
*>(rawYamlContext
);
320 if (auto attr
= dyn_cast_or_null
<AffineMapAttr
>(
321 mlir::parseAttribute(scalar
, yamlContext
->mlirContext
)))
322 value
.affineMapAttr
= attr
;
323 else if (!value
.affineMapAttr
|| !isa
<AffineMapAttr
>(value
.affineMapAttr
))
324 return "could not parse as an affine map attribute";
327 static QuotingType
mustQuote(StringRef
) { return QuotingType::None
; }
335 //===----------------------------------------------------------------------===//
336 // Generation utilities
337 //===----------------------------------------------------------------------===//
339 class GenerationContext
{
341 GenerationContext(MLIRContext
*context
, raw_ostream
*odsOut
,
342 raw_ostream
*defnOut
)
343 : context(context
), loc(UnknownLoc::get(context
)), odsOut(odsOut
),
346 MLIRContext
*getContext() { return context
; }
348 void setLoc(Location loc
) { this->loc
= loc
; }
349 Location
getLoc() { return loc
; }
351 bool shouldGenerateOds() { return odsOut
; }
352 bool shouldGenerateDefns() { return defnOut
; }
354 raw_ostream
&odss() {
355 assert(odsOut
&& "ODS stream not defined");
359 raw_ostream
&defns() {
360 assert(defnOut
&& "Definition stream not defined");
365 MLIRContext
*context
;
368 raw_ostream
*defnOut
;
373 static std::string
generateCppExpression(SerializedAffineMap self
,
374 StringRef contextName
) {
375 std::string printedStr
;
376 llvm::raw_string_ostream
printedSs(printedStr
);
377 self
.affineMapAttr
.print(printedSs
);
380 static const char exprFormat
[] =
381 R
"FMT(mlir::parseAttribute("{0}", {1}).cast<AffineMapAttr>().getValue())FMT";
382 return llvm::formatv(exprFormat
, printedStr
, contextName
);
385 template <typename Container
>
386 static std::string
interleaveToString(Container
&container
,
387 StringRef separator
) {
389 llvm::raw_string_ostream
ss(result
);
390 llvm::interleave(container
, ss
, separator
);
395 static std::optional
<int>
396 findTensorDefArgIndex(StringRef name
, SmallVectorImpl
<LinalgOperandDef
> &args
) {
397 for (const auto &it
: llvm::enumerate(args
)) {
398 if (it
.value().name
== name
)
404 // Try to map the TypeVar to a predefined or an argument type.
405 static std::optional
<std::string
>
406 findTypeValue(StringRef typeVar
, SmallVectorImpl
<LinalgOperandDef
> &args
) {
407 // Handle all predefined types.
408 if (typeVar
== "I32")
409 return std::string("helper.getIntegerType(32)");
410 if (typeVar
== "I64")
411 return std::string("helper.getIntegerType(64)");
412 if (typeVar
== "F32")
413 return std::string("helper.getFloat32Type()");
414 if (typeVar
== "F64")
415 return std::string("helper.getFloat64Type()");
417 // Search all argument types.
418 for (const auto &it
: llvm::enumerate(args
)) {
419 if (it
.value().kind
!= LinalgOperandDefKind::InputTensor
&&
420 it
.value().kind
!= LinalgOperandDefKind::Scalar
&&
421 it
.value().kind
!= LinalgOperandDefKind::OutputTensor
)
423 if (*it
.value().typeVar
== typeVar
)
424 return llvm::formatv("block.getArgument({0}).getType()", it
.index())
431 static ScalarAssign
*findAssignment(StringRef name
,
432 std::vector
<ScalarAssign
> &assignments
) {
433 for (auto &assign
: assignments
) {
434 if (assign
.arg
== name
)
440 // Return true if the operand is a function attribute.
441 static bool isFunctionAttribute(LinalgOperandDefKind kind
) {
442 return kind
== LinalgOperandDefKind::UnaryFnAttr
||
443 kind
== LinalgOperandDefKind::BinaryFnAttr
||
444 kind
== LinalgOperandDefKind::TypeFnAttr
;
447 // Return true if the operand is an attribute.
448 static bool isAttribute(LinalgOperandDefKind kind
) {
449 return kind
== LinalgOperandDefKind::IndexAttr
|| isFunctionAttribute(kind
);
452 // Get the enum name for the given operand kind.
453 std::string
convertOperandKindToEnumName(LinalgOperandDefKind kind
) {
455 case LinalgOperandDefKind::UnaryFnAttr
:
456 return std::string("UnaryFn");
457 case LinalgOperandDefKind::BinaryFnAttr
:
458 return std::string("BinaryFn");
459 case LinalgOperandDefKind::TypeFnAttr
:
460 return std::string("TypeFn");
464 llvm_unreachable("unsupported function attribute kind");
467 // Get the enum name for the given function kind.
468 std::string
convertFunctionKindToEnumName(ScalarFnKind kind
) {
470 case ScalarFnKind::Unary
:
471 return std::string("UnaryFn");
472 case ScalarFnKind::Binary
:
473 return std::string("BinaryFn");
474 case ScalarFnKind::Type
:
475 return std::string("TypeFn");
477 llvm_unreachable("unsupported function kind");
480 //===----------------------------------------------------------------------===//
482 //===----------------------------------------------------------------------===//
484 // A single line banner format. Parameters:
485 // {0}: Single line comment
486 static const char bannerFormat
[] = R
"FMT(
487 //===----------------------------------------------------------------------===//
489 //===----------------------------------------------------------------------===//
492 //===----------------------------------------------------------------------===//
493 // Named generic op generation.
494 // These ops map at most a single contraction that complies with the limitations
495 // of a linalg.generic.
496 //===----------------------------------------------------------------------===//
498 // Template for Linalg named ops' ODS definitions. Parameters:
499 // {0}: ODS/C++ op name
500 // {1}: assembly op mnemonic
501 // {2}: op interface list
502 // {3}: documentation (summary + description)
503 // {4}: op attribute list
504 // {5}: builder methods taking standalone attribute parameters
505 // {6}: additional method definitions
506 // {7}: additional methods for attributes used by indexing maps
507 static const char structuredOpOdsHeaderFormat
[] = R
"FMT(
508 //===----------------------------------------------------------------------===//
509 // Op definition for {0}
510 //===----------------------------------------------------------------------===//
512 def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
513 /*extraInterfaces=*/[{2}])> {
516 Variadic<AnyType>:$inputs,
517 Variadic<AnyShaped>:$outputs{4}
519 let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
520 let regions = (region AnyRegion:$region);
522 let skipDefaultBuilders = 1;
525 (ins "ValueRange
":$inputs, "ValueRange
":$outputs,
526 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
528 buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
529 attributes, {0}::getRegionBuilder());
532 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$inputs,
533 "ValueRange
":$outputs,
534 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
536 buildStructuredOp($_builder, $_state, resultTensorTypes,
537 inputs, outputs, attributes, {0}::getRegionBuilder());
540 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$operands,
541 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
543 $_state.addOperands(operands);
544 $_state.addAttributes(attributes);
545 $_state.addTypes(resultTensorTypes);
546 (void)$_state.addRegion();
550 let hasCustomAssemblyFormat = 1;
554 let extraClassDeclaration = structuredOpsBaseDecls # [{{
556 SmallVector<utils::IteratorType> getIteratorTypesArray();
557 ArrayAttr getIndexingMaps();
558 static void regionBuilder(ImplicitLocOpBuilder &b,
559 Block &block, ArrayRef<NamedAttribute> attrs);
560 static std::function<void(ImplicitLocOpBuilder &,
561 Block &, ArrayRef<NamedAttribute>)>
562 getRegionBuilder() {{
563 return regionBuilder;
566 std::pair<int64_t, int64_t> getDpsInitsPositionRange() {{
567 int64_t getNumOperands = this->getNumOperands();
568 return {{getNumOperands - 1, getNumOperands};
572 static unsigned getNumRegionArgs();
573 std::string getLibraryCallName();
579 // Builder method taking attribute parameters. Parameters:
581 // {1}: Comma interleaved attribute parameters
582 // {2}: Attribute initialization
583 static const char structuredOpBuilderFormat
[] = R
"FMT(
585 (ins "TypeRange
":$resultTensorTypes, "ValueRange
":$inputs,
586 "ValueRange
":$outputs, {1},
587 CArg<"ArrayRef
<NamedAttribute
>", "{{}">:$attributes),
590 buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
591 attributes, {0}::getRegionBuilder());
595 // The getIteratorTypesArray() method for structured ops. Parameters:
597 // {1}: Comma interleaved iterator type names.
598 static const char structuredOpIteratorTypesFormat
[] =
600 SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
601 return SmallVector<utils::IteratorType>{{ {1} };
605 // The getIteratorTypesArray() method for rank polymorphic structured ops.
608 static const char rankPolyStructuredOpIteratorTypesFormat
[] =
610 SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
611 int64_t rank = getRank(getDpsInitOperand(0));
612 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
616 // The getIndexingMaps() method for structured ops. Parameters:
618 // {1}: Comma-separated list of dimension variable names.
620 static const char structuredOpIndexingMapsFormat
[] = R
"FMT(
621 ArrayAttr {0}::getIndexingMaps() {{
622 static const char memoizeAttr[] = "linalg
.memoized_indexing_maps
";
623 ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
627 MLIRContext *context = getContext();
628 auto symbolBindings = getSymbolBindings(*this);
629 SmallVector<AffineMap> maps;
631 cached = Builder(context).getAffineMapArrayAttr(maps);
632 getOperation()->setAttr(memoizeAttr, cached);
637 // The getIndexingMaps() method for rank polymorphic structured ops. Parameters:
639 static const char rankPolyStructuredOpIndexingMapsFormat
[] = R
"FMT(
640 ArrayAttr {0}::getIndexingMaps() {{
641 MLIRContext *context = getContext();
642 AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
643 AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
644 getNumParallelLoops(), context);
645 SmallVector<AffineMap> indexingMaps;
646 for (OpOperand &opOperand : getOperation()->getOpOperands())
647 indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
648 return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
652 // Implementations of fold and getEffects.
655 const char structuredOpFoldersFormat
[] = R
"FMT(
656 LogicalResult {0}::fold(FoldAdaptor,
657 SmallVectorImpl<OpFoldResult> &) {{
658 return memref::foldMemRefCast(*this);
660 void {0}::getEffects(SmallVectorImpl<
661 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
662 if (hasTensorSemantics()) return;
663 getGenericEffectsImpl(effects,
664 getOperation()->getResults(), getDpsInputOperands(), getDpsInitOperands());
668 // Implementation of parse/print.
671 static const char structuredOpParserFormat
[] = R
"FMT(
672 ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
673 return ::parseNamedStructuredOp(parser, result,
674 {0}::getNumRegionArgs(), {0}::getRegionBuilder());
676 void {0}::print(OpAsmPrinter &p) {{
677 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
681 static LogicalResult
generateNamedGenericOpOds(LinalgOpConfig
&opConfig
,
682 GenerationContext
&genContext
) {
683 if (!genContext
.shouldGenerateOds())
686 raw_ostream
&os
= genContext
.odss();
688 std::string interfaceNameList
;
689 std::string attrList
;
690 std::string attrMethods
;
691 std::string attrBuilder
;
694 if (opConfig
.metadata
->doc
) {
695 static const char structuredOpDocFmt
[] = R
"FMT(
696 let summary = [{ {0} }];
701 StringRef summary
, description
;
702 std::tie(summary
, description
) =
703 StringRef(*opConfig
.metadata
->doc
).trim().split('\n');
704 doc
= llvm::formatv(structuredOpDocFmt
, summary
.trim(), description
.trim());
707 interfaceNameList
= interleaveToString(opConfig
.metadata
->implements
, ", ");
709 std::string definitionList
;
710 for (const std::string
&definition
: opConfig
.metadata
->defines
) {
711 static const char definitionFmt
[] = "let {0} = 1;\n";
712 definitionList
.append(llvm::formatv(definitionFmt
, definition
));
715 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
716 return isAttribute(arg
.kind
);
718 SmallVector
<std::string
> attrDefs
;
719 SmallVector
<std::string
> attrParams
;
720 SmallVector
<std::string
> attrStmts
;
721 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
722 static const char paramFmt
[] = "\"Attribute\":${0}";
723 static const char stmtFmt
[] = "$_state.addAttribute(\"{0}\", {0});";
724 // Add the type conversion attributes to the op definition and builders.
725 if (isFunctionAttribute(arg
.kind
)) {
726 assert(arg
.defaultFn
);
727 std::string enumName
= convertOperandKindToEnumName(arg
.kind
);
728 static const char typeFmt
[] = "{0}::{1}";
729 static const char defFmt
[] =
730 "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}";
731 attrDefs
.push_back(llvm::formatv(
732 defFmt
, llvm::formatv("{0}Attr", enumName
),
733 llvm::formatv(typeFmt
, enumName
, arg
.defaultFn
), arg
.name
));
734 attrParams
.push_back(llvm::formatv(paramFmt
, arg
.name
));
735 attrStmts
.push_back(llvm::formatv(stmtFmt
, arg
.name
));
737 // Add the index attributes to the op definition and builders.
738 if (arg
.kind
== LinalgOperandDefKind::IndexAttr
) {
739 assert(arg
.indexAttrMap
.has_value());
740 assert(arg
.defaultIndices
.has_value());
741 size_t size
= arg
.indexAttrMap
->affineMap().getNumResults();
742 assert(arg
.defaultIndices
->size() == size
);
743 static const char typeFmt
[] = "RankedI64ElementsAttr<[{0}]>";
744 static const char defFmt
[] =
745 "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}";
746 std::string defaultVals
;
747 llvm::raw_string_ostream
ss(defaultVals
);
749 *arg
.defaultIndices
, ss
,
750 [&](int64_t val
) { ss
<< "static_cast<int64_t>(" << val
<< ")"; },
752 attrDefs
.push_back(llvm::formatv(defFmt
, llvm::formatv(typeFmt
, size
),
753 ss
.str(), arg
.name
));
754 attrParams
.push_back(llvm::formatv(paramFmt
, arg
.name
));
755 attrStmts
.push_back(llvm::formatv(stmtFmt
, arg
.name
));
758 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
759 return arg
.kind
== LinalgOperandDefKind::IndexAttr
;
762 bool hasDynamicIndexingMaps();
763 LogicalResult verifyIndexingMapRequiredAttributes();
766 attrList
= ",\n" + llvm::join(attrDefs
, ",\n");
767 attrBuilder
= llvm::formatv(
768 structuredOpBuilderFormat
, opConfig
.metadata
->cppClassName
,
769 llvm::join(attrParams
, ", "), llvm::join(attrStmts
, "\n"));
772 os
<< llvm::formatv(structuredOpOdsHeaderFormat
,
773 opConfig
.metadata
->cppClassName
, opConfig
.metadata
->name
,
774 interfaceNameList
, doc
, attrList
, attrBuilder
,
775 definitionList
, attrMethods
);
781 generateNamedGenericOpDefns(LinalgOpConfig
&opConfig
,
782 GenerationContext
&genContext
) {
783 if (!genContext
.shouldGenerateDefns())
786 raw_ostream
&os
= genContext
.defns();
787 StringRef className
= opConfig
.metadata
->cppClassName
;
789 // Implementation banner.
790 std::string bannerComment
= llvm::formatv("Implementation of {0}", className
);
791 os
<< llvm::formatv(bannerFormat
, bannerComment
);
793 // Compute the number of scalar and tensor arguments.
795 llvm::count_if(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
796 return arg
.kind
== LinalgOperandDefKind::InputTensor
||
797 arg
.kind
== LinalgOperandDefKind::Scalar
||
798 arg
.kind
== LinalgOperandDefKind::OutputTensor
;
801 // An operation that accesses only scalars and scalar/rank zero tensors is
802 // rank polymorphic. We implement rank polymorphism by generating different
803 // indexing maps and iterators that match the rank of the first output tensor.
804 // An operation is rank polymorphic if the iteration domain has rank zero.
805 bool isRankPolymorphic
= opConfig
.structuredOp
->iteratorTypes
.empty();
807 // Generate the iterator_types() method.
808 if (!isRankPolymorphic
) {
809 std::string iteratorsStr
;
810 llvm::raw_string_ostream
ss(iteratorsStr
);
811 llvm::interleaveComma(opConfig
.structuredOp
->iteratorTypes
, ss
,
812 [&](LinalgIteratorTypeDef it
) {
814 case LinalgIteratorTypeDef::parallel
:
815 ss
<< "utils::IteratorType::parallel";
817 case LinalgIteratorTypeDef::reduction
:
818 ss
<< "utils::IteratorType::reduction";
823 os
<< llvm::formatv(structuredOpIteratorTypesFormat
, className
,
826 os
<< llvm::formatv(rankPolyStructuredOpIteratorTypesFormat
, className
);
829 // Generating the getIndexingMaps() method.
830 if (auto &staticMaps
=
831 opConfig
.structuredOp
->indexingMaps
.staticIndexingMaps
) {
832 if (staticMaps
->empty())
833 return emitError(genContext
.getLoc()) << "op has no indexing maps";
834 if (!isRankPolymorphic
) {
835 AffineMap firstMap
= staticMaps
->front().affineMap();
839 // For each symbol, generate a declaration for it, either with an
840 // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
842 // TODO: Possibly lift into a top-level method.
843 static const char structuredOpSymbolBindingsFormat
[] = R
"FMT(
844 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
845 MLIRContext *context = self.getContext();
846 SmallVector<AffineExpr> exprs;
852 unsigned symbolCount
= firstMap
.getNumSymbols();
853 SmallVector
<std::string
> symbolBindings
;
854 for (unsigned i
= 0; i
< symbolCount
; ++i
) {
855 symbolBindings
.push_back(llvm::formatv(
856 " exprs.push_back(getAffineSymbolExpr({0}, context));", i
));
859 // Access an index attribute. Parameters:
860 // {0}: Attribute name
861 // {1}: Symbol position
862 // {2}: Attribute index
863 static const char structuredOpAccessAttrFormat
[] = R
"FMT(
864 int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}];
865 exprs.push_back(getAffineConstantExpr(cst{1}, context));
867 // Update all symbol bindings mapped to an attribute.
868 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
869 if (arg
.kind
!= LinalgOperandDefKind::IndexAttr
)
871 assert(arg
.indexAttrMap
);
872 for (auto [idx
, result
] :
873 llvm::enumerate(arg
.indexAttrMap
->affineMap().getResults())) {
874 if (auto symbol
= result
.dyn_cast
<AffineSymbolExpr
>()) {
875 std::string argName
= arg
.name
;
876 argName
[0] = toupper(argName
[0]);
877 symbolBindings
[symbol
.getPosition()] =
878 llvm::formatv(structuredOpAccessAttrFormat
, argName
,
879 symbol
.getPosition(), idx
);
884 std::string symbolBindingsStr
;
885 llvm::raw_string_ostream
symbolBindingsSs(symbolBindingsStr
);
886 llvm::interleave(symbolBindings
, symbolBindingsSs
, "\n");
887 symbolBindingsSs
.flush();
889 os
<< llvm::formatv(structuredOpSymbolBindingsFormat
, className
,
895 unsigned dimCount
= firstMap
.getNumDims();
897 // Generate a comma-separated list of dim identifiers to be passed to
898 // bindDims, ensuring that AffineExpr identifiers are bound in the right
899 // order to the proper AffineDimExpr.
900 // This results in vars in scope like: d0, d1, d2...
901 SmallVector
<unsigned> dimIndices
;
902 for (unsigned i
= 0; i
< dimCount
; ++i
)
903 dimIndices
.push_back(i
);
904 std::string dimIdentsStr
;
905 llvm::raw_string_ostream
dimIdentsSs(dimIdentsStr
);
906 llvm::interleaveComma(dimIndices
, dimIdentsSs
,
907 [&](unsigned i
) { dimIdentsSs
<< "d" << i
; });
910 // Statements to add and simplify each affine map.
911 SmallVector
<std::string
> stmts
;
912 for (auto &indexingMap
: *staticMaps
) {
913 // TODO: Assert that dim and symbol count match the first.
915 llvm::formatv("maps.push_back({0});",
916 generateCppExpression(indexingMap
, "context")));
917 stmts
.push_back(llvm::formatv(
919 "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
920 "symbolBindings, {0}, 0));",
924 // TODO: This needs to be memoized and/or converted to non-parser based
925 // C++ codegen prior to real use.
926 os
<< llvm::formatv(structuredOpIndexingMapsFormat
, className
,
927 dimIdentsStr
, interleaveToString(stmts
, "\n "));
930 os
<< llvm::formatv(rankPolyStructuredOpIndexingMapsFormat
, className
);
933 return emitError(genContext
.getLoc())
934 << "generating code for non static indexing maps not currently "
938 // getNumRegionArgs()
940 // Generates a getNumRegionArgs() method. Parameters:
942 // {1}: Number of region args
943 static const char structuredOpGetNumRegionArgsFormat
[] = R
"FMT(
944 unsigned {0}::getNumRegionArgs() {{ return {1}; }
946 os
<< llvm::formatv(structuredOpGetNumRegionArgsFormat
, className
,
950 // getLibraryCallName()
952 // Generates a getLibraryCallName method. Parameters:
954 static const char structuredOpGetLibraryCallFormat
[] = R
"FMT(
955 std::string {0}::getLibraryCallName() {{
956 return generateLibraryCallName(getOperation());
959 os
<< llvm::formatv(structuredOpGetLibraryCallFormat
, className
);
962 // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
963 if (llvm::any_of(opConfig
.structuredOp
->args
, [](LinalgOperandDef
&arg
) {
964 return arg
.kind
== LinalgOperandDefKind::IndexAttr
;
966 std::vector
<std::string
> attrVerifications
;
967 for (LinalgOperandDef
&arg
: opConfig
.structuredOp
->args
) {
968 if (arg
.kind
!= LinalgOperandDefKind::IndexAttr
)
970 assert(arg
.indexAttrMap
);
971 // Verify index attribute. Paramters:
972 // {0}: Attribute name
973 // {1}: Attribute size
974 static const char attrFmt
[] = R
"FMT(
975 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
976 if (!attr.getType().getElementType().isInteger(64))
977 return op->emitError("incorrect element type
for index attribute
'{0}'");
978 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
979 return op->emitError("incorrect shape
for index attribute
'{0}'");
982 attrVerifications
.push_back(llvm::formatv(
983 attrFmt
, arg
.name
, arg
.indexAttrMap
->affineMap().getNumResults()));
986 // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
988 // {1}: Attribute verification
989 static const char structuredOpVerifyIndexingMapRequiredAttributes
[] = R
"FMT(
990 bool {0}::hasDynamicIndexingMaps() {{ return true; }
991 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
992 Operation *op = getOperation();
997 os
<< llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes
,
998 className
, llvm::join(attrVerifications
, "\n"));
1003 // Generates a regionBuilder method. Parameters.
1005 // {1}: Number of args
1008 static const char structuredOpRegionBuilderFormat
[] = R
"FMT(
1009 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1010 Block &block, ArrayRef<NamedAttribute> attrs) {{
1011 assert({1} > 0 && block.getNumArguments() == {1} &&
1012 "{0} regionBuilder expects
{1} (>=0) args
");
1013 RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
1014 SmallVector<Value> yields;
1017 helper.yieldOutputs(yields);
1020 auto &args
= opConfig
.structuredOp
->args
;
1021 auto &assignments
= opConfig
.structuredOp
->assignments
;
1022 size_t generatedAssignmentCount
= 0;
1023 int localCounter
= 0;
1024 SmallVector
<std::string
> attrs
;
1025 SmallVector
<std::string
> stmts
;
1026 for (LinalgOperandDef
&arg
: args
) {
1027 if (!isFunctionAttribute(arg
.kind
))
1029 // Obtain the type function attribute values. Parameters.
1031 // {1}: attribute name
1032 // {2}: default type function name
1033 static const char attrDef
[] = R
"FMT(
1034 {0} {1}Val = {0}::{2};
1035 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1036 return attr.getName() == "{1}"; });
1037 if ({1}Iter != attrs.end()) {{
1038 if (auto attr = {1}Iter->getValue().dyn_cast<{0}Attr>())
1039 {1}Val = attr.getValue();
1042 std::string enumName
= convertOperandKindToEnumName(arg
.kind
);
1044 llvm::formatv(attrDef
, enumName
, arg
.name
, arg
.defaultFn
));
1046 for (LinalgOperandDef
&arg
: args
) {
1047 if (arg
.kind
!= LinalgOperandDefKind::OutputTensor
)
1050 // Find the assignment that correlates with the argument.
1051 ScalarAssign
*assignment
= findAssignment(arg
.name
, assignments
);
1053 return emitError(genContext
.getLoc())
1054 << "no assignment found for output argument " << arg
.name
;
1055 ++generatedAssignmentCount
;
1057 // Recursively generate the expression.
1058 std::function
<std::optional
<std::string
>(ScalarExpression
&)>
1059 generateExpression
=
1060 [&](ScalarExpression
&expression
) -> std::optional
<std::string
> {
1061 if (expression
.arg
) {
1062 // Argument reference.
1063 std::optional
<int> argIndex
=
1064 findTensorDefArgIndex(*expression
.arg
, args
);
1066 emitError(genContext
.getLoc())
1067 << "scalar argument not defined on the op: " << *expression
.arg
;
1068 return std::nullopt
;
1071 llvm::formatv("block.getArgument({0})", *argIndex
));
1073 if (expression
.constant
) {
1074 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1076 llvm::formatv(R
"FMT(Value {0} = helper.constant("{1}");)FMT",
1077 cppIdent
, expression
.constant
));
1080 if (expression
.index
) {
1081 // Access an iteration index.
1082 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1083 stmts
.push_back(llvm::formatv("Value {0} = helper.index({1});",
1084 cppIdent
, *expression
.index
));
1087 if (expression
.scalarFn
) {
1088 std::string enumName
=
1089 convertFunctionKindToEnumName(expression
.scalarFn
->kind
);
1091 // Get the function or attribute name.
1092 assert(expression
.scalarFn
->fnName
|| expression
.scalarFn
->attrName
);
1093 std::string funcType
;
1094 if (expression
.scalarFn
->fnName
) {
1095 funcType
= llvm::formatv("{0}::{1}", enumName
,
1096 *expression
.scalarFn
->fnName
);
1098 if (expression
.scalarFn
->attrName
) {
1099 if (llvm::none_of(args
, [&](LinalgOperandDef
&arg
) {
1100 return isFunctionAttribute(arg
.kind
) &&
1101 arg
.name
== *expression
.scalarFn
->attrName
;
1103 emitError(genContext
.getLoc()) << "missing function attribute "
1104 << *expression
.scalarFn
->attrName
;
1106 funcType
= llvm::formatv("{0}Val", *expression
.scalarFn
->attrName
);
1108 assert(!funcType
.empty());
1110 // Add the optional type parameter to the operands.
1111 SmallVector
<std::string
> operandCppValues
;
1112 if (expression
.scalarFn
->kind
== ScalarFnKind::Type
) {
1113 assert(expression
.scalarFn
->typeVar
.has_value());
1114 std::optional
<std::string
> typeCppValue
=
1115 findTypeValue(*expression
.scalarFn
->typeVar
, args
);
1116 if (!typeCppValue
) {
1117 emitError(genContext
.getLoc())
1118 << "type variable " << *expression
.scalarFn
->typeVar
1119 << ", used in a type conversion, must map to a predefined or "
1120 << "an argument type but it does not";
1121 return std::nullopt
;
1123 operandCppValues
.push_back(*typeCppValue
);
1126 // Collect the scalar operands.
1127 for (ScalarExpression
&operand
: expression
.scalarFn
->operands
) {
1128 auto operandCppValue
= generateExpression(operand
);
1129 if (!operandCppValue
)
1130 return std::nullopt
;
1131 operandCppValues
.push_back(*operandCppValue
);
1134 // Call the function builder.
1135 std::string cppIdent
= llvm::formatv("value{0}", ++localCounter
);
1136 stmts
.push_back(llvm::formatv(
1137 "Value {0} = helper.build{1}({2}, {3});", cppIdent
, enumName
,
1138 funcType
, interleaveToString(operandCppValues
, ", ")));
1141 emitError(genContext
.getLoc()) << "unknown ScalarExpression type";
1142 return std::nullopt
;
1144 std::optional
<std::string
> cppValue
=
1145 generateExpression(assignment
->value
);
1148 stmts
.push_back(llvm::formatv("yields.push_back({0});", *cppValue
));
1151 if (generatedAssignmentCount
!= assignments
.size())
1152 return emitError(genContext
.getLoc())
1153 << "mismatched number of assignments vs output arguments";
1155 os
<< llvm::formatv(structuredOpRegionBuilderFormat
, className
, numOfArgs
,
1156 interleaveToString(attrs
, "\n "),
1157 interleaveToString(stmts
, "\n "));
1160 // Parser and printer.
1161 os
<< llvm::formatv(structuredOpParserFormat
, className
);
1163 // Canonicalizers and folders.
1164 os
<< llvm::formatv(structuredOpFoldersFormat
, className
);
1169 static LogicalResult
generateOp(LinalgOpConfig
&opConfig
,
1170 GenerationContext
&genContext
) {
1171 // Switch on op type being generated.
1172 if (opConfig
.structuredOp
) {
1174 succeeded(generateNamedGenericOpOds(opConfig
, genContext
)) &&
1175 succeeded(generateNamedGenericOpDefns(opConfig
, genContext
)));
1177 return emitError(genContext
.getLoc()) << "unsupported operation type";
1180 //===----------------------------------------------------------------------===//
1181 // Command line options and main
1182 //===----------------------------------------------------------------------===//
1184 static llvm::cl::opt
<std::string
>
1185 inputFilename(llvm::cl::Positional
, llvm::cl::desc("<input file>"),
1186 llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
1188 static llvm::cl::opt
<std::string
>
1189 outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1190 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1192 static llvm::cl::opt
<std::string
>
1193 outputCppImplFilename("o-impl",
1194 llvm::cl::desc("C++ implementation file name"),
1195 llvm::cl::value_desc("filename"), llvm::cl::init(""));
1197 int main(int argc
, char **argv
) {
1198 llvm::cl::ParseCommandLineOptions(argc
, argv
, "Linalg ODS Gen from YAML");
1200 // Set up the input file.
1201 std::string errorMessage
;
1202 std::unique_ptr
<llvm::MemoryBuffer
> file
=
1203 mlir::openInputFile(inputFilename
, &errorMessage
);
1205 llvm::errs() << errorMessage
<< "\n";
1209 MLIRContext mlirContext
;
1210 LinalgYAMLContext yamlContext
{&mlirContext
};
1212 std::vector
<LinalgOpConfig
> opConfigs
;
1215 Input
yin(file
->getBuffer(), &yamlContext
);
1221 // Open output files.
1222 std::unique_ptr
<llvm::ToolOutputFile
> outputOdsDecl
;
1223 if (!outputOdsDeclFilename
.empty()) {
1224 outputOdsDecl
= openOutputFile(outputOdsDeclFilename
, &errorMessage
);
1225 if (!outputOdsDecl
) {
1226 llvm::errs() << errorMessage
<< "\n";
1231 std::unique_ptr
<llvm::ToolOutputFile
> outputCppImpl
;
1232 if (!outputCppImplFilename
.empty()) {
1233 outputCppImpl
= openOutputFile(outputCppImplFilename
, &errorMessage
);
1234 if (!outputCppImpl
) {
1235 llvm::errs() << errorMessage
<< "\n";
1240 if (!outputOdsDecl
&& !outputCppImpl
) {
1241 llvm::errs() << "error: No output files specified\n";
1246 GenerationContext
genContext(&mlirContext
,
1247 outputOdsDecl
? &outputOdsDecl
->os() : nullptr,
1248 outputCppImpl
? &outputCppImpl
->os() : nullptr);
1250 for (auto &opConfig
: opConfigs
) {
1251 if (!opConfig
.metadata
) {
1252 emitError(genContext
.getLoc())
1253 << "missing operation metadata on subsequent op";
1257 genContext
.setLoc(NameLoc::get(
1258 StringAttr::get(&mlirContext
, opConfig
.metadata
->cppClassName
)));
1259 if (failed(generateOp(opConfig
, genContext
))) {
1265 outputOdsDecl
->keep();
1267 outputCppImpl
->keep();