//===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing the transformation that drops unit
// extent dimensions from `linalg.generic` operations.
//
//===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 LogicalResult
dropOutermostUnitDims(RewriterBase
&rewriter
,
25 linalg::GenericOp genericOp
) {
26 linalg::ControlDropUnitDims options
;
27 options
.controlFn
= [](Operation
*op
) { return SmallVector
<unsigned>{0}; };
28 FailureOr
<linalg::DropUnitDimsResult
> result
=
29 linalg::dropUnitDims(rewriter
, genericOp
, options
);
33 rewriter
.replaceOp(genericOp
, result
->replacements
);
37 struct TestLinalgDropUnitDims
38 : public PassWrapper
<TestLinalgDropUnitDims
, OperationPass
<func::FuncOp
>> {
40 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims
)
42 TestLinalgDropUnitDims() = default;
43 TestLinalgDropUnitDims(const TestLinalgDropUnitDims
&pass
) = default;
45 void getDependentDialects(DialectRegistry
®istry
) const override
{
46 registry
.insert
<linalg::LinalgDialect
>();
49 StringRef
getArgument() const final
{ return "test-linalg-drop-unit-dims"; }
51 StringRef
getDescriptions() const {
52 return "Test transformation to drop unit-extent dims from Linalg "
56 void runOnOperation() override
{
57 MLIRContext
*context
= &this->getContext();
58 func::FuncOp funcOp
= this->getOperation();
59 IRRewriter
rewriter(context
);
60 SmallVector
<linalg::GenericOp
> genericOps
;
62 [&](linalg::GenericOp genericOp
) { genericOps
.push_back(genericOp
); });
64 for (auto genericOp
: genericOps
) {
65 rewriter
.setInsertionPoint(genericOp
);
66 (void)dropOutermostUnitDims(rewriter
, genericOp
);
74 void registerTestLinalgDropUnitDims() {
75 PassRegistration
<TestLinalgDropUnitDims
>();