//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::mesh;
// Test rewrite pattern for `ShardOp`s. For each user of `op` that is itself a
// `ShardOp` annotated for users and whose sharding resolves to the same mesh
// as `op`'s, it casts `op`'s source to the sharded shape, calls `reshard`,
// casts the result back to the user's result type, and then replaces uses of
// the user's result (the replacement argument is cut off at the end of this
// view).
//
// NOTE(review): this block comes from a lossy extraction. Original line
// numbers are fused into the text and several original lines are missing (see
// the NOTEs below), so it does not compile as shown. Restore it from version
// control rather than reconstructing the missing logic by hand.
29 struct TestMeshReshardingRewritePattern
: OpRewritePattern
<ShardOp
> {
30 using OpRewritePattern
<ShardOp
>::OpRewritePattern
;
// Entry point of the pattern; `op` is the `ShardOp` whose users are rewritten.
32 LogicalResult
matchAndRewrite(ShardOp op
,
33 PatternRewriter
&rewriter
) const override
{
34 if (op
.getAnnotateForUsers()) {
// NOTE(review): original lines 35-37 are missing here, so the body of the
// `op.getAnnotateForUsers()` check above is not visible.
//
// Resolve the mesh referenced by `op`'s sharding, starting the symbol lookup
// at `op`. `cast<ShardingOp>` asserts if the sharding value is not produced
// by a `ShardingOp`.
38 SymbolTableCollection symbolTable
;
39 mesh::MeshOp mesh
= symbolTable
.lookupNearestSymbolFrom
<mesh::MeshOp
>(
40 op
, cast
<ShardingOp
>(op
.getSharding().getDefiningOp()).getMeshAttr());
// First pass over the users: looks for a `ShardOp` user that is annotated for
// users and on the same mesh. Only the start of this loop is visible.
42 bool foundUser
= false;
43 for (auto user
: op
->getUsers()) {
44 if (auto targetShardOp
= llvm::dyn_cast
<ShardOp
>(user
)) {
45 if (targetShardOp
.getAnnotateForUsers() &&
46 mesh
== symbolTable
.lookupNearestSymbolFrom
<mesh::MeshOp
>(
49 targetShardOp
.getSharding().getDefiningOp())
// NOTE(review): original lines 50-60 are missing here. The end of the mesh
// comparison above, the rest of the first loop, and any later use of
// `foundUser` are not visible.
//
// Second pass over the users: rewrite each qualifying target `ShardOp`.
61 for (auto user
: op
->getUsers()) {
62 auto targetShardOp
= llvm::dyn_cast
<ShardOp
>(user
);
// A user qualifies only if it is a `ShardOp`, is annotated for users, and its
// sharding resolves to the same mesh as `op`'s.
63 if (!targetShardOp
|| !targetShardOp
.getAnnotateForUsers() ||
64 symbolTable
.lookupNearestSymbolFrom
<mesh::MeshOp
>(
66 cast
<ShardingOp
>(targetShardOp
.getSharding().getDefiningOp())
67 .getMeshAttr()) != mesh
) {
// NOTE(review): original lines 68-70 (the body of the `if` above) are
// missing. Presumably it skips non-qualifying users; confirm.
//
// Materialize `op`'s source at the sharded shape via an
// `UnrealizedConversionCastOp`, call `reshard` with the source and target
// `ShardOp`s, then cast the result back to the target's result type.
71 ImplicitLocOpBuilder
builder(op
->getLoc(), rewriter
);
72 ShapedType sourceShardShape
=
73 shardShapedType(op
.getResult().getType(), mesh
, op
.getSharding());
74 TypedValue
<ShapedType
> sourceShard
= cast
<TypedValue
<ShapedType
>>(
// NOTE(review): original line 75 (the receiver of `.create` below, presumably
// `builder`) is missing.
76 .create
<UnrealizedConversionCastOp
>(sourceShardShape
, op
.getSrc())
// NOTE(review): original line 77, which completes the `sourceShard`
// initializer, is missing.
78 TypedValue
<ShapedType
> targetShard
=
79 reshard(builder
, mesh
, op
, targetShardOp
, sourceShard
);
80 Value newTargetUnsharded
=
// NOTE(review): original line 81 (the receiver of `.create` below) is missing.
82 .create
<UnrealizedConversionCastOp
>(
83 targetShardOp
.getResult().getType(), targetShard
)
// NOTE(review): original line 84, which completes the `newTargetUnsharded`
// initializer, is missing.
85 rewriter
.replaceAllUsesWith(targetShardOp
.getResult(),
// NOTE(review): the definition is cut off here. Original lines 86-92 are not
// visible: the rest of this call and, presumably, the closing of the loop,
// the method and the struct.
93 struct TestMeshReshardingPass
94 : public PassWrapper
<TestMeshReshardingPass
, OperationPass
<ModuleOp
>> {
95 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass
)
97 void runOnOperation() override
{
98 RewritePatternSet
patterns(&getContext());
99 patterns
.insert
<TestMeshReshardingRewritePattern
>(&getContext());
100 if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
101 std::move(patterns
)))) {
102 return signalPassFailure();
105 void getDependentDialects(DialectRegistry
®istry
) const override
{
106 reshardingRegisterDependentDialects(registry
);
107 registry
.insert
<BuiltinDialect
>();
109 StringRef
getArgument() const final
{
110 return "test-mesh-resharding-spmdization";
112 StringRef
getDescription() const final
{
113 return "Test Mesh dialect resharding spmdization.";
// Registers `TestMeshReshardingPass` by constructing a `PassRegistration` for
// it.
// NOTE(review): this definition continues past the visible source, and its
// closing brace is not shown. Confirm it is declared in the namespace its
// callers expect.
120 void registerTestMeshReshardingSpmdizationPass() {
121 PassRegistration
<TestMeshReshardingPass
>();