Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Mesh / TestReshardingSpmdization.cpp
blob 98992c4cc11f92f57f69d7738f96368cf385fab4
//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
11 #include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
12 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
13 #include "mlir/IR/BuiltinDialect.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/BuiltinTypeInterfaces.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/IR/Value.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 using namespace mlir;
25 using namespace mlir::mesh;
27 namespace {
29 struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
30 using OpRewritePattern<ShardOp>::OpRewritePattern;
32 LogicalResult matchAndRewrite(ShardOp op,
33 PatternRewriter &rewriter) const override {
34 if (op.getAnnotateForUsers()) {
35 return failure();
38 SymbolTableCollection symbolTable;
39 mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
40 op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
42 bool foundUser = false;
43 for (auto user : op->getUsers()) {
44 if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
45 if (targetShardOp.getAnnotateForUsers() &&
46 mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
47 targetShardOp,
48 cast<ShardingOp>(
49 targetShardOp.getSharding().getDefiningOp())
50 .getMeshAttr())) {
51 foundUser = true;
52 break;
57 if (!foundUser) {
58 return failure();
61 for (auto user : op->getUsers()) {
62 auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
63 if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
64 symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
65 targetShardOp,
66 cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
67 .getMeshAttr()) != mesh) {
68 continue;
71 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
72 ShapedType sourceShardShape =
73 shardShapedType(op.getResult().getType(), mesh, op.getSharding());
74 TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
75 builder
76 .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
77 ->getResult(0));
78 TypedValue<ShapedType> targetShard =
79 reshard(builder, mesh, op, targetShardOp, sourceShard);
80 Value newTargetUnsharded =
81 builder
82 .create<UnrealizedConversionCastOp>(
83 targetShardOp.getResult().getType(), targetShard)
84 ->getResult(0);
85 rewriter.replaceAllUsesWith(targetShardOp.getResult(),
86 newTargetUnsharded);
89 return success();
93 struct TestMeshReshardingPass
94 : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
95 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
97 void runOnOperation() override {
98 RewritePatternSet patterns(&getContext());
99 patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
100 if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
101 std::move(patterns)))) {
102 return signalPassFailure();
105 void getDependentDialects(DialectRegistry &registry) const override {
106 reshardingRegisterDependentDialects(registry);
107 registry.insert<BuiltinDialect>();
109 StringRef getArgument() const final {
110 return "test-mesh-resharding-spmdization";
112 StringRef getDescription() const final {
113 return "Test Mesh dialect resharding spmdization.";
116 } // namespace
118 namespace mlir {
119 namespace test {
120 void registerTestMeshReshardingSpmdizationPass() {
121 PassRegistration<TestMeshReshardingPass>();
123 } // namespace test
124 } // namespace mlir