//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
11 #include "mlir/Dialect/Utils/IndexingUtils.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 struct TestAllSliceOpLoweringPass
21 : public PassWrapper
<TestAllSliceOpLoweringPass
, OperationPass
<>> {
22 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass
)
24 void runOnOperation() override
{
25 RewritePatternSet
patterns(&getContext());
26 SymbolTableCollection symbolTableCollection
;
27 mesh::populateAllSliceOpLoweringPatterns(patterns
, symbolTableCollection
);
28 LogicalResult status
=
29 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
));
31 assert(succeeded(status
) && "applyPatternsAndFoldGreedily failed.");
33 void getDependentDialects(DialectRegistry
®istry
) const override
{
34 mesh::registerAllSliceOpLoweringDialects(registry
);
36 StringRef
getArgument() const final
{
37 return "test-mesh-all-slice-op-lowering";
39 StringRef
getDescription() const final
{
40 return "Test lowering of all-slice.";
44 struct TestMultiIndexOpLoweringPass
45 : public PassWrapper
<TestMultiIndexOpLoweringPass
, OperationPass
<>> {
46 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass
)
48 void runOnOperation() override
{
49 RewritePatternSet
patterns(&getContext());
50 SymbolTableCollection symbolTableCollection
;
51 mesh::populateProcessMultiIndexOpLoweringPatterns(patterns
,
52 symbolTableCollection
);
53 LogicalResult status
=
54 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
));
56 assert(succeeded(status
) && "applyPatternsAndFoldGreedily failed.");
58 void getDependentDialects(DialectRegistry
®istry
) const override
{
59 mesh::registerProcessMultiIndexOpLoweringDialects(registry
);
61 StringRef
getArgument() const final
{
62 return "test-mesh-process-multi-index-op-lowering";
64 StringRef
getDescription() const final
{
65 return "Test lowering of mesh.process_multi_index op.";
73 void registerTestOpLoweringPasses() {
74 PassRegistration
<TestAllSliceOpLoweringPass
>();
75 PassRegistration
<TestMultiIndexOpLoweringPass
>();