//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"

// The generated pass base class must live in the `mlir` namespace so that
// `impl::ConvertLinalgToStandardBase` below resolves.
namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;
28 static MemRefType
makeStridedLayoutDynamic(MemRefType type
) {
29 return MemRefType::Builder(type
).setLayout(StridedLayoutAttr::get(
30 type
.getContext(), ShapedType::kDynamic
,
31 SmallVector
<int64_t>(type
.getRank(), ShapedType::kDynamic
)));
34 /// Helper function to extract the operand types that are passed to the
35 /// generated CallOp. MemRefTypes have their layout canonicalized since the
36 /// information is not used in signature generation.
37 /// Note that static size information is not modified.
38 static SmallVector
<Type
, 4> extractOperandTypes(Operation
*op
) {
39 SmallVector
<Type
, 4> result
;
40 result
.reserve(op
->getNumOperands());
41 for (auto type
: op
->getOperandTypes()) {
42 // The underlying descriptor type (e.g. LLVM) does not have layout
43 // information. Canonicalizing the type at the level of std when going into
44 // a library call avoids needing to introduce DialectCastOp.
45 if (auto memrefType
= dyn_cast
<MemRefType
>(type
))
46 result
.push_back(makeStridedLayoutDynamic(memrefType
));
48 result
.push_back(type
);
53 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
54 // If the library function does not exist, insert a declaration.
55 static FailureOr
<FlatSymbolRefAttr
>
56 getLibraryCallSymbolRef(Operation
*op
, PatternRewriter
&rewriter
) {
57 auto linalgOp
= cast
<LinalgOp
>(op
);
58 auto fnName
= linalgOp
.getLibraryCallName();
60 return rewriter
.notifyMatchFailure(op
, "No library call defined for: ");
62 // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
63 FlatSymbolRefAttr fnNameAttr
=
64 SymbolRefAttr::get(rewriter
.getContext(), fnName
);
65 auto module
= op
->getParentOfType
<ModuleOp
>();
66 if (module
.lookupSymbol(fnNameAttr
.getAttr()))
69 SmallVector
<Type
, 4> inputTypes(extractOperandTypes(op
));
70 if (op
->getNumResults() != 0) {
71 return rewriter
.notifyMatchFailure(
73 "Library call for linalg operation can be generated only for ops that "
74 "have void return types");
76 auto libFnType
= rewriter
.getFunctionType(inputTypes
, {});
78 OpBuilder::InsertionGuard
guard(rewriter
);
79 // Insert before module terminator.
80 rewriter
.setInsertionPoint(module
.getBody(),
81 std::prev(module
.getBody()->end()));
82 func::FuncOp funcOp
= rewriter
.create
<func::FuncOp
>(
83 op
->getLoc(), fnNameAttr
.getValue(), libFnType
);
84 // Insert a function attribute that will trigger the emission of the
85 // corresponding `_mlir_ciface_xxx` interface so that external libraries see
86 // a normalized ABI. This interface is added during std to llvm conversion.
87 funcOp
->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
88 UnitAttr::get(op
->getContext()));
93 static SmallVector
<Value
, 4>
94 createTypeCanonicalizedMemRefOperands(OpBuilder
&b
, Location loc
,
95 ValueRange operands
) {
96 SmallVector
<Value
, 4> res
;
97 res
.reserve(operands
.size());
98 for (auto op
: operands
) {
99 auto memrefType
= dyn_cast
<MemRefType
>(op
.getType());
105 b
.create
<memref::CastOp
>(loc
, makeStridedLayoutDynamic(memrefType
), op
);
111 LogicalResult
mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
112 LinalgOp op
, PatternRewriter
&rewriter
) const {
113 auto libraryCallName
= getLibraryCallSymbolRef(op
, rewriter
);
114 if (failed(libraryCallName
))
117 // TODO: Add support for more complex library call signatures that include
118 // indices or captured values.
119 rewriter
.replaceOpWithNewOp
<func::CallOp
>(
120 op
, libraryCallName
->getValue(), TypeRange(),
121 createTypeCanonicalizedMemRefOperands(rewriter
, op
->getLoc(),
126 /// Populate the given list with patterns that convert from Linalg to Standard.
127 void mlir::linalg::populateLinalgToStandardConversionPatterns(
128 RewritePatternSet
&patterns
) {
129 // TODO: ConvOp conversion needs to export a descriptor with relevant
130 // attribute values such as kernel striding and dilation.
131 patterns
.add
<LinalgOpToLibraryCallRewrite
>(patterns
.getContext());
135 struct ConvertLinalgToStandardPass
136 : public impl::ConvertLinalgToStandardBase
<ConvertLinalgToStandardPass
> {
137 void runOnOperation() override
;
141 void ConvertLinalgToStandardPass::runOnOperation() {
142 auto module
= getOperation();
143 ConversionTarget
target(getContext());
144 target
.addLegalDialect
<affine::AffineDialect
, arith::ArithDialect
,
145 func::FuncDialect
, memref::MemRefDialect
,
147 target
.addLegalOp
<ModuleOp
, func::FuncOp
, func::ReturnOp
>();
148 RewritePatternSet
patterns(&getContext());
149 populateLinalgToStandardConversionPatterns(patterns
);
150 if (failed(applyFullConversion(module
, target
, std::move(patterns
))))
154 std::unique_ptr
<OperationPass
<ModuleOp
>>
155 mlir::createConvertLinalgToStandardPass() {
156 return std::make_unique
<ConvertLinalgToStandardPass
>();