Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Mesh / TestOpLowering.cpp
blob1f836be1ae7ac19a51a9f56f64d00f44e30ef746
1 //===- TestOpLowering.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 using namespace mlir;
18 namespace {
20 struct TestAllSliceOpLoweringPass
21 : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> {
22 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass)
24 void runOnOperation() override {
25 RewritePatternSet patterns(&getContext());
26 SymbolTableCollection symbolTableCollection;
27 mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
28 LogicalResult status =
29 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
30 (void)status;
31 assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
33 void getDependentDialects(DialectRegistry &registry) const override {
34 mesh::registerAllSliceOpLoweringDialects(registry);
36 StringRef getArgument() const final {
37 return "test-mesh-all-slice-op-lowering";
39 StringRef getDescription() const final {
40 return "Test lowering of all-slice.";
44 struct TestMultiIndexOpLoweringPass
45 : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
46 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
48 void runOnOperation() override {
49 RewritePatternSet patterns(&getContext());
50 SymbolTableCollection symbolTableCollection;
51 mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
52 symbolTableCollection);
53 LogicalResult status =
54 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
55 (void)status;
56 assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
58 void getDependentDialects(DialectRegistry &registry) const override {
59 mesh::registerProcessMultiIndexOpLoweringDialects(registry);
61 StringRef getArgument() const final {
62 return "test-mesh-process-multi-index-op-lowering";
64 StringRef getDescription() const final {
65 return "Test lowering of mesh.process_multi_index op.";
69 } // namespace
71 namespace mlir {
72 namespace test {
73 void registerTestOpLoweringPasses() {
74 PassRegistration<TestAllSliceOpLoweringPass>();
75 PassRegistration<TestMultiIndexOpLoweringPass>();
77 } // namespace test
78 } // namespace mlir