1 //===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the Linalg dialect types and dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
25 #include "mlir/Interfaces/FunctionInterfaces.h"
26 #include "mlir/Interfaces/SubsetOpInterface.h"
27 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
28 #include "mlir/Parser/Parser.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Transforms/InliningUtils.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
37 using namespace mlir::linalg
;
39 //===----------------------------------------------------------------------===//
40 // LinalgDialect Dialect Interfaces
41 //===----------------------------------------------------------------------===//
45 struct LinalgInlinerInterface
: public DialectInlinerInterface
{
46 using DialectInlinerInterface::DialectInlinerInterface
;
48 // We don't have any special restrictions on what can be inlined into
49 // destination regions (e.g. while/conditional bodies). Always allow it.
50 bool isLegalToInline(Region
*dest
, Region
*src
, bool wouldBeCloned
,
51 IRMapping
&valueMapping
) const final
{
54 // Operations in Linalg dialect are always legal to inline.
55 bool isLegalToInline(Operation
*, Region
*, bool, IRMapping
&) const final
{
58 // Handle the given inlined terminator by replacing it with a new operation
59 // as necessary. Required when the region has only one block.
60 void handleTerminator(Operation
*op
, ValueRange valuesToRepl
) const final
{}
65 //===----------------------------------------------------------------------===//
67 //===----------------------------------------------------------------------===//
69 /// Attribute name used to memoize indexing maps for named ops.
70 constexpr const ::llvm::StringLiteral
71 LinalgDialect::kMemoizedIndexingMapsAttrName
;
73 /// Trait to check if T provides a `regionBuilder` method.
74 template <typename T
, typename
... Args
>
75 using has_region_builder
= decltype(T::regionBuilder
);
77 using detect_has_region_builder
= llvm::is_detected
<has_region_builder
, T
>;
79 /// SFINAE helper for single C++ class without a `regionBuilder` method (e.g.
81 template <typename OpType
, typename
= std::enable_if_t
<
82 !detect_has_region_builder
<OpType
>::value
>>
83 void addNamedOpBuilderImpl(
84 llvm::StringMap
<LinalgDialect::RegionBuilderFunType
> &map
) {
88 template <typename OpType
,
89 typename
= std::enable_if_t
<detect_has_region_builder
<OpType
>::value
>,
91 void addNamedOpBuilderImpl(
92 llvm::StringMap
<LinalgDialect::RegionBuilderFunType
> &map
) {
93 map
.insert(std::make_pair(
94 OpType::getOperationName(),
95 static_cast<LinalgDialect::RegionBuilderFunType
>(OpType::regionBuilder
)));
98 template <typename
... OpTypes
>
99 void addNamedOpBuilders(
100 llvm::StringMap
<LinalgDialect::RegionBuilderFunType
> &map
) {
101 (addNamedOpBuilderImpl
<OpTypes
>(map
), ...);
104 void mlir::linalg::LinalgDialect::initialize() {
106 #define GET_ATTRDEF_LIST
107 #include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
111 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
115 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
118 // Fill the Linalg-specific OpName to RegionBuilder map.
121 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
122 >(namedStructuredOpRegionBuilders
);
124 addInterfaces
<LinalgInlinerInterface
>();
126 declarePromisedInterface
<mesh::ShardingInterface
, GenericOp
>();
127 declarePromisedInterfaces
<mesh::ShardingInterface
,
129 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
131 declarePromisedInterface
<SubsetOpInterface
, CopyOp
>();
132 declarePromisedInterface
<SubsetInsertionOpInterface
, CopyOp
>();
133 declarePromisedInterface
<ValueBoundsOpInterface
, IndexOp
>();
134 declarePromisedInterface
<TilingInterface
, linalg::GenericOp
>();
135 declarePromisedInterface
<PartialReductionOpInterface
, linalg::GenericOp
>();
136 declarePromisedInterfaces
<TilingInterface
,
138 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
140 declarePromisedInterfaces
<PartialReductionOpInterface
,
142 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
144 declarePromisedInterfaces
<bufferization::BufferizableOpInterface
,
146 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
150 LogicalResult
LinalgDialect::verifyOperationAttribute(Operation
*op
,
151 NamedAttribute attr
) {
152 if (attr
.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName
)
154 return op
->emitError() << "attribute '" << attr
.getName()
155 << "' not supported by the linalg dialect";
158 #include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc"
160 #define GET_ATTRDEF_CLASSES
161 #include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.cpp.inc"
163 #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"