[lldb] Remove some unused code in SymbolFileDWARF::ResolveFunction (#123206)
[llvm-project.git] / mlir / lib / Conversion / ArithToArmSME / ArithToArmSME.cpp
blobcbe0b3fda3410d861c5570b2bec71dd927e701a4
1 //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
13 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 namespace mlir {
18 #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19 #include "mlir/Conversion/Passes.h.inc"
20 } // namespace mlir
22 #define DEBUG_TYPE "arith-to-arm-sme"
24 using namespace mlir;
26 //===----------------------------------------------------------------------===//
27 // Conversion helpers
28 //===----------------------------------------------------------------------===//
30 /// Returns true if 'val' is a splat of zero, false otherwise.
31 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
32 if (llvm::isa<FloatType>(elemType))
33 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
34 if (llvm::isa<IntegerType>(elemType))
35 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
36 return false;
39 namespace {
41 //===----------------------------------------------------------------------===//
42 // ConstantOp
43 //===----------------------------------------------------------------------===//
45 /// Conversion pattern for dense arith.constant.
46 struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
47 using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
49 LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50 PatternRewriter &rewriter) const final {
51 auto tileType = dyn_cast<VectorType>(constantOp.getType());
52 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
53 return failure();
55 auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56 if (!denseAttr || !denseAttr.isSplat())
57 return failure();
59 auto tileElementType = tileType.getElementType();
61 // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62 if (isSplatZero(tileElementType, denseAttr)) {
63 rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
64 return success();
67 // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice'
68 // ops that broadcast the constant to each tile slice.
69 auto loc = constantOp.getLoc();
71 // To fill a tile with a constant, we create a 1-D splat of the constant,
72 // then move that into each tile slice (the largest unit we can set at once,
73 // outside of operations like the outerproduct).
74 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
75 auto denseAttr1D = DenseElementsAttr::get(
76 tileSliceType, denseAttr.getSplatValue<Attribute>());
77 auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
79 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
80 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
81 Value currentTile) {
82 // Create 'arm_sme.insert_tile_slice' to write vector to tile
83 // slice.
84 auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
85 loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86 return nextTile.getResult();
88 auto forOp = mlir::arm_sme::createLoopOverTileSlices(
89 rewriter, loc, initTile, makeLoopBody);
90 rewriter.replaceOp(constantOp, forOp.getResult(0));
92 return success();
96 } // namespace
98 //===----------------------------------------------------------------------===//
99 // Pattern population
100 //===----------------------------------------------------------------------===//
102 void mlir::arith::populateArithToArmSMEConversionPatterns(
103 RewritePatternSet &patterns) {
104 patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
107 //===----------------------------------------------------------------------===//
108 // Pass definition
109 //===----------------------------------------------------------------------===//
111 namespace {
112 struct ArithToArmSMEConversionPass final
113 : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
114 using impl::ArithToArmSMEConversionPassBase<
115 ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
117 void runOnOperation() override {
118 RewritePatternSet patterns(&getContext());
119 arith::populateArithToArmSMEConversionPatterns(patterns);
120 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
121 return signalPassFailure();
124 } // namespace