1 //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
13 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19 #include "mlir/Conversion/Passes.h.inc"
22 #define DEBUG_TYPE "arith-to-arm-sme"
//===----------------------------------------------------------------------===//
// Conversion helpers
//===----------------------------------------------------------------------===//
30 /// Returns true if 'val' is a splat of zero, false otherwise.
31 static bool isSplatZero(Type elemType
, DenseElementsAttr val
) {
32 if (llvm::isa
<FloatType
>(elemType
))
33 return val
&& val
.isSplat() && val
.getSplatValue
<APFloat
>().isZero();
34 if (llvm::isa
<IntegerType
>(elemType
))
35 return val
&& val
.isSplat() && val
.getSplatValue
<APInt
>().isZero();
//===----------------------------------------------------------------------===//
// ConstantOp lowering
//===----------------------------------------------------------------------===//
45 /// Conversion pattern for dense arith.constant.
46 struct ConstantOpToArmSMELowering
: public OpRewritePattern
<arith::ConstantOp
> {
47 using OpRewritePattern
<arith::ConstantOp
>::OpRewritePattern
;
49 LogicalResult
matchAndRewrite(arith::ConstantOp constantOp
,
50 PatternRewriter
&rewriter
) const final
{
51 auto tileType
= dyn_cast
<VectorType
>(constantOp
.getType());
52 if (!tileType
|| !arm_sme::isValidSMETileVectorType(tileType
))
55 auto denseAttr
= dyn_cast
<DenseElementsAttr
>(constantOp
.getValueAttr());
56 if (!denseAttr
|| !denseAttr
.isSplat())
59 auto tileElementType
= tileType
.getElementType();
61 // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62 if (isSplatZero(tileElementType
, denseAttr
)) {
63 rewriter
.replaceOpWithNewOp
<arm_sme::ZeroOp
>(constantOp
, tileType
);
67 // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice'
68 // ops that broadcast the constant to each tile slice.
69 auto loc
= constantOp
.getLoc();
71 // To fill a tile with a constant, we create a 1-D splat of the constant,
72 // then move that into each tile slice (the largest unit we can set at once,
73 // outside of operations like the outerproduct).
74 VectorType tileSliceType
= VectorType::Builder(tileType
).dropDim(0);
75 auto denseAttr1D
= DenseElementsAttr::get(
76 tileSliceType
, denseAttr
.getSplatValue
<Attribute
>());
77 auto constantOp1D
= rewriter
.create
<arith::ConstantOp
>(loc
, denseAttr1D
);
79 auto initTile
= rewriter
.create
<arm_sme::GetTileOp
>(loc
, tileType
);
80 auto makeLoopBody
= [&](OpBuilder
&b
, Location loc
, Value tileSliceIndex
,
82 // Create 'arm_sme.insert_tile_slice' to write vector to tile
84 auto nextTile
= b
.create
<arm_sme::InsertTileSliceOp
>(
85 loc
, tileType
, constantOp1D
, currentTile
, tileSliceIndex
);
86 return nextTile
.getResult();
88 auto forOp
= mlir::arm_sme::createLoopOverTileSlices(
89 rewriter
, loc
, initTile
, makeLoopBody
);
90 rewriter
.replaceOp(constantOp
, forOp
.getResult(0));
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
102 void mlir::arith::populateArithToArmSMEConversionPatterns(
103 RewritePatternSet
&patterns
) {
104 patterns
.add
<ConstantOpToArmSMELowering
>(patterns
.getContext());
//===----------------------------------------------------------------------===//
// Pass definition
//===----------------------------------------------------------------------===//
112 struct ArithToArmSMEConversionPass final
113 : impl::ArithToArmSMEConversionPassBase
<ArithToArmSMEConversionPass
> {
114 using impl::ArithToArmSMEConversionPassBase
<
115 ArithToArmSMEConversionPass
>::ArithToArmSMEConversionPassBase
;
117 void runOnOperation() override
{
118 RewritePatternSet
patterns(&getContext());
119 arith::populateArithToArmSMEConversionPatterns(patterns
);
120 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns
))))
121 return signalPassFailure();