Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Linalg / TestLinalgDropUnitDims.cpp
blob402ce154c0848e4a110750b112e26a44b96c7d79
1 //===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass for testing the transformation to drop unit
10 // extent dimensions from `linalg.generic` operations.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 using namespace mlir;
22 namespace {
24 LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
25 linalg::GenericOp genericOp) {
26 linalg::ControlDropUnitDims options;
27 options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
28 FailureOr<linalg::DropUnitDimsResult> result =
29 linalg::dropUnitDims(rewriter, genericOp, options);
30 if (failed(result)) {
31 return failure();
33 rewriter.replaceOp(genericOp, result->replacements);
34 return success();
37 struct TestLinalgDropUnitDims
38 : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> {
40 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims)
42 TestLinalgDropUnitDims() = default;
43 TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass) = default;
45 void getDependentDialects(DialectRegistry &registry) const override {
46 registry.insert<linalg::LinalgDialect>();
49 StringRef getArgument() const final { return "test-linalg-drop-unit-dims"; }
51 StringRef getDescriptions() const {
52 return "Test transformation to drop unit-extent dims from Linalg "
53 "operations";
56 void runOnOperation() override {
57 MLIRContext *context = &this->getContext();
58 func::FuncOp funcOp = this->getOperation();
59 IRRewriter rewriter(context);
60 SmallVector<linalg::GenericOp> genericOps;
61 funcOp.walk(
62 [&](linalg::GenericOp genericOp) { genericOps.push_back(genericOp); });
64 for (auto genericOp : genericOps) {
65 rewriter.setInsertionPoint(genericOp);
66 (void)dropOutermostUnitDims(rewriter, genericOp);
70 } // namespace
72 namespace mlir {
73 namespace test {
74 void registerTestLinalgDropUnitDims() {
75 PassRegistration<TestLinalgDropUnitDims>();
77 } // namespace test
78 } // namespace mlir