1 //===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements lowering of ArmSME operations to SCF.
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
16 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
22 #define GEN_PASS_DEF_CONVERTARMSMETOSCF
23 #include "mlir/Conversion/Passes.h.inc"
29 /// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
30 /// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
31 /// rank 2: (indices[0] + tileSliceIndex, indices[1])
32 SmallVector
<Value
, 2> getMemrefIndices(ValueRange indices
, unsigned rank
,
34 Value tileSliceNumElts
, Location loc
,
35 PatternRewriter
&rewriter
) {
36 assert((rank
== 1 || rank
== 2) && "memref has unexpected rank!");
37 SmallVector
<Value
, 2> outIndices
;
39 auto tileSliceOffset
= tileSliceIndex
;
42 rewriter
.create
<arith::MulIOp
>(loc
, tileSliceOffset
, tileSliceNumElts
);
44 auto baseIndexPlusTileSliceOffset
=
45 rewriter
.create
<arith::AddIOp
>(loc
, indices
[0], tileSliceOffset
);
46 outIndices
.push_back(baseIndexPlusTileSliceOffset
);
49 outIndices
.push_back(indices
[1]);
54 /// Creates an scf.for for the load/store of an ArmSME tile.
55 FailureOr
<scf::ForOp
> createLoadStoreForOverTileSlices(
56 PatternRewriter
&rewriter
, Location loc
, VectorType tileType
,
57 ValueRange memrefIndices
, int memrefRank
, Value mask
, Value initTile
,
58 function_ref
<Value(/*index=*/Value
, ValueRange
, /*predicate=*/Value
,
59 /*currentTile=*/Value
)>
61 PatternRewriter::InsertionGuard
guard(rewriter
);
63 auto minTileSlices
= rewriter
.create
<arith::ConstantIndexOp
>(
64 loc
, arm_sme::getSMETileSliceMinNumElts(tileType
.getElementType()));
66 rewriter
.create
<vector::VectorScaleOp
>(loc
, rewriter
.getIndexType());
68 VectorType::get(tileType
.getDimSize(1), rewriter
.getI1Type(), true);
70 // This describes both the number of ZA tile slices and the number of
71 // elements in a vector of SVL bits for a given element type (SVL_B,
72 // SVL_H, ..., SVL_Q).
74 rewriter
.create
<arith::MulIOp
>(loc
, minTileSlices
, vscale
);
79 auto createMaskOp
= mask
.getDefiningOp
<vector::CreateMaskOp
>();
81 return rewriter
.notifyMatchFailure(
82 loc
, "unsupported mask op, only 'vector.create_mask' is "
83 "currently supported");
85 auto maskDim0
= createMaskOp
.getOperands()[0];
86 auto maskDim1
= createMaskOp
.getOperands()[1];
88 // The upper bound of the loop must be clamped at `numTileSlices` as
89 // `vector.create_mask` allows operands to be greater than the size of a
91 auto numRowI64
= rewriter
.create
<arith::IndexCastOp
>(
92 loc
, rewriter
.getI64Type(), maskDim0
);
93 auto numTileSlicesI64
= rewriter
.create
<arith::IndexCastOp
>(
94 loc
, rewriter
.getI64Type(), numTileSlices
);
96 rewriter
.create
<arith::MinSIOp
>(loc
, numRowI64
, numTileSlicesI64
);
97 upperBound
= rewriter
.create
<arith::IndexCastOp
>(
98 loc
, rewriter
.getIndexType(), upperBoundI64
);
101 rewriter
.create
<vector::CreateMaskOp
>(loc
, predicateType
, maskDim1
);
103 upperBound
= numTileSlices
;
104 // No mask. Create an 'all true' predicate for the tile slice.
105 predicate
= rewriter
.create
<arith::ConstantOp
>(
106 loc
, DenseElementsAttr::get(predicateType
, true));
109 bool hasCarriedArgs
= bool(initTile
);
110 auto lowerBound
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 0);
111 auto step
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 1);
112 auto forOp
= rewriter
.create
<scf::ForOp
>(loc
, lowerBound
, upperBound
, step
,
113 hasCarriedArgs
? ValueRange
{initTile
}
116 rewriter
.setInsertionPointToStart(forOp
.getBody());
117 Value tileSliceIndex
= forOp
.getInductionVar();
119 auto adjustedIndices
= getMemrefIndices(
120 memrefIndices
, memrefRank
, tileSliceIndex
, numTileSlices
, loc
, rewriter
);
121 auto nextTile
= makeLoopBody(
122 tileSliceIndex
, adjustedIndices
, predicate
,
123 /*currentTile=*/hasCarriedArgs
? forOp
.getRegionIterArg(0) : Value
{});
125 assert(bool(nextTile
) == hasCarriedArgs
);
127 rewriter
.create
<scf::YieldOp
>(loc
, nextTile
);
132 FailureOr
<scf::ForOp
> createLoadStoreForOverTileSlices(
133 PatternRewriter
&rewriter
, Location loc
, VectorType tileType
,
134 ValueRange memrefIndices
, int memrefRank
, Value mask
,
135 function_ref
<void(/*index=*/Value
, ValueRange
, /*predicate=*/Value
)>
137 return createLoadStoreForOverTileSlices(
138 rewriter
, loc
, tileType
, memrefIndices
, memrefRank
, mask
, Value
{},
139 [&](Value index
, ValueRange adjustedIndices
, Value predicate
,
141 makeLoopBody(index
, adjustedIndices
, predicate
);
146 /// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
152 /// %pad = arith.constant 0 : i32
153 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
154 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
155 /// memref<?x?xi32>, vector<[4]x[4]xi32>
160 /// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
161 /// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
162 /// %loop_rows = arith.minsi %num_rows, %svl_s : index
163 /// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
164 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
165 /// %tile_update = arm_sme.load_tile_slice
166 /// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
167 /// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
168 /// scf.yield %tile_update : vector<[4]x[4]xi32>
172 /// Without a mask the lowering is pretty much identical. The only difference is
173 /// %mask_cols becomes an all-true mask, and %loop_rows becomes %svl_s.
175 /// NOTE: Only mask of 'vector.create_mask' op is currently supported.
176 struct TileLoadOpConversion
: public OpRewritePattern
<arm_sme::TileLoadOp
> {
177 using OpRewritePattern
<arm_sme::TileLoadOp
>::OpRewritePattern
;
179 LogicalResult
matchAndRewrite(arm_sme::TileLoadOp tileLoadOp
,
180 PatternRewriter
&rewriter
) const override
{
181 auto loc
= tileLoadOp
.getLoc();
182 auto tileType
= tileLoadOp
.getVectorType();
183 auto mask
= tileLoadOp
.getMask();
187 auto padOp
= tileLoadOp
.getPadding();
188 assert(padOp
&& "expected padding when masking!");
190 auto constPadOp
= padOp
.getDefiningOp
<arith::ConstantOp
>();
191 if (!constPadOp
|| constPadOp
.getValue() !=
192 rewriter
.getZeroAttr(tileType
.getElementType()))
193 return rewriter
.notifyMatchFailure(
194 tileLoadOp
, "op has non-zero pad, needs non-zero pad pattern");
196 // Initialize tile with zero to satisfy padding. Inactive cols will be
197 // zeroed anyway since the loads use zeroing predication. For inactive
198 // rows however, no load will occur so these need to be zeroed.
199 initTile
= rewriter
.create
<arm_sme::ZeroOp
>(loc
, tileType
);
201 initTile
= rewriter
.create
<arm_sme::GetTileOp
>(loc
, tileType
);
204 // Create a loop to load the active tile slices from memory.
205 auto forOp
= createLoadStoreForOverTileSlices(
206 rewriter
, loc
, tileType
, tileLoadOp
.getIndices(),
207 tileLoadOp
.getMemRefType().getRank(), mask
, initTile
,
208 [&](Value tileSliceIndex
, ValueRange memrefIndices
, Value predicate
,
209 Value currentTile
) -> Value
{
210 // Create 'arm_sme.load_tile_slice' to load tile slice from memory
212 return rewriter
.create
<arm_sme::LoadTileSliceOp
>(
213 loc
, tileType
, tileLoadOp
.getBase(), predicate
, currentTile
,
214 memrefIndices
, tileSliceIndex
, tileLoadOp
.getLayout());
220 // Replace 'arm_sme.tile_load' with the result.
221 rewriter
.replaceOp(tileLoadOp
, forOp
->getResult(0));
227 /// Lower `arm_sme.tile_load` with mask and non-zero pad.
231 /// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
232 /// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
233 /// memref<?x?xi32>, vector<[4]x[4]xi32>
239 /// %pad_1d = vector.splat %pad : vector<[4]xi32>
240 /// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
241 /// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
243 /// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
244 /// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
245 /// : memref<?x?xi32>, vector<[4]xi1>,
246 /// vector<[4]xi32> into vector<[4]xi32>
247 /// // Insert slice into tile
248 /// %tile_update = arm_sme.insert_tile_slice
249 /// %slice, %iter_tile[%tile_slice_idx] :
250 /// vector<[4]xi32> into vector<[4]x[4]xi32>
251 /// scf.yield %tile_update : vector<[4]x[4]xi32>
254 struct TileLoadOpWithMaskAndPadNonZeroConversion
255 : public OpRewritePattern
<arm_sme::TileLoadOp
> {
256 using OpRewritePattern
<arm_sme::TileLoadOp
>::OpRewritePattern
;
258 LogicalResult
matchAndRewrite(arm_sme::TileLoadOp tileLoadOp
,
259 PatternRewriter
&rewriter
) const override
{
260 OpBuilder::InsertionGuard
g(rewriter
);
261 auto loc
= tileLoadOp
.getLoc();
262 auto tileType
= tileLoadOp
.getVectorType();
263 auto tileElementType
= tileType
.getElementType();
265 auto maskOp
= tileLoadOp
.getMask();
267 return rewriter
.notifyMatchFailure(
268 tileLoadOp
, "op has no mask, needs unmasked pattern");
270 auto padOp
= tileLoadOp
.getPadding();
271 assert(padOp
&& "expected padding when masking!");
273 auto createMaskOp
= maskOp
.getDefiningOp
<vector::CreateMaskOp
>();
275 return rewriter
.notifyMatchFailure(
276 tileLoadOp
, "unsupported mask op, only 'vector.create_mask' is "
277 "currently supported");
279 auto constPadOp
= padOp
.getDefiningOp
<arith::ConstantOp
>();
281 constPadOp
.getValue() == rewriter
.getZeroAttr(tileElementType
))
282 return rewriter
.notifyMatchFailure(
283 tileLoadOp
, "op has constant zero pad, needs zero pad pattern");
285 auto numRows
= createMaskOp
.getOperands()[0];
286 auto numCols
= createMaskOp
.getOperands()[1];
288 auto numColsI32
= rewriter
.create
<arith::IndexCastUIOp
>(
289 loc
, rewriter
.getI32Type(), numCols
);
291 auto initTile
= rewriter
.create
<arm_sme::GetTileOp
>(loc
, tileType
);
293 // Create a loop that loads each ZA tile slice from memory.
294 auto step
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 1);
295 auto minTileSlices
= rewriter
.create
<arith::ConstantIndexOp
>(
296 loc
, arm_sme::getSMETileSliceMinNumElts(tileElementType
));
298 rewriter
.create
<vector::VectorScaleOp
>(loc
, rewriter
.getIndexType());
299 auto lowerBound
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 0);
301 rewriter
.create
<arith::MulIOp
>(loc
, minTileSlices
, vscale
);
302 auto forOp
= rewriter
.create
<scf::ForOp
>(loc
, lowerBound
, numTileSlices
,
303 step
, ValueRange
{initTile
});
305 rewriter
.setInsertionPointToStart(forOp
.getBody());
307 auto tileSliceIndex
= forOp
.getInductionVar();
308 auto currentTile
= forOp
.getRegionIterArg(0);
311 auto rowIsActive
= rewriter
.create
<arith::CmpIOp
>(
312 loc
, arith::CmpIPredicate::ult
, tileSliceIndex
, numRows
);
313 auto rowIsActiveI32
= rewriter
.create
<arith::ExtSIOp
>(
314 loc
, rewriter
.getI32Type(), rowIsActive
);
315 auto mask
= rewriter
.create
<arith::AndIOp
>(loc
, rowIsActiveI32
, numColsI32
);
317 rewriter
.create
<arith::IndexCastOp
>(loc
, rewriter
.getIndexType(), mask
);
319 VectorType::get(tileType
.getDimSize(1), rewriter
.getI1Type(), true);
320 auto maskOp1D
= rewriter
.create
<vector::CreateMaskOp
>(
321 loc
, predicateType
, maskIndex
.getResult());
323 auto memrefIndices
= getMemrefIndices(
324 tileLoadOp
.getIndices(), tileLoadOp
.getMemRefType().getRank(),
325 tileSliceIndex
, numTileSlices
, loc
, rewriter
);
327 // Splat pad into 1-D vector matching type of tile slice.
328 VectorType tileSliceType
= VectorType::Builder(tileType
).dropDim(0);
329 auto pad1DOp
= rewriter
.create
<vector::SplatOp
>(loc
, tileSliceType
, padOp
);
331 auto loadSlice
= rewriter
.create
<vector::MaskedLoadOp
>(
332 loc
, tileSliceType
, tileLoadOp
.getBase(), memrefIndices
, maskOp1D
,
333 /*passthru=*/pad1DOp
);
335 // Create 'arm_sme.insert_tile_slice' to insert slice into tile.
336 auto insertSlice
= rewriter
.create
<arm_sme::InsertTileSliceOp
>(
337 loc
, tileType
, loadSlice
->getResult(0), currentTile
, tileSliceIndex
,
338 tileLoadOp
.getLayout());
339 rewriter
.create
<scf::YieldOp
>(loc
, insertSlice
.getResult());
341 rewriter
.setInsertionPointAfter(forOp
);
343 // Replace 'arm_sme.tile_load' with the result.
344 rewriter
.replaceOp(tileLoadOp
, forOp
.getResult(0));
350 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
351 /// slice using `arm_sme.store_tile_slice`.
355 /// arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
356 /// : memref<?x?xi32>, vector<[4]x[4]xi32
361 /// %vscale = vector.vscale
362 /// %c0 = arith.constant 0 : index
363 /// %c1 = arith.constant 1 : index
364 /// %min_svl_s = arith.constant 4 : index
365 /// %svl_s = arith.muli %min_svl_s, %vscale : index
366 /// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
367 /// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
368 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
371 struct TileStoreOpConversion
: public OpRewritePattern
<arm_sme::TileStoreOp
> {
372 using OpRewritePattern
<arm_sme::TileStoreOp
>::OpRewritePattern
;
374 LogicalResult
matchAndRewrite(arm_sme::TileStoreOp tileStoreOp
,
375 PatternRewriter
&rewriter
) const override
{
376 // Create a loop that stores each active ZA tile slice from memory.
377 return createLoadStoreForOverTileSlices(
378 rewriter
, tileStoreOp
.getLoc(), tileStoreOp
.getVectorType(),
379 tileStoreOp
.getIndices(), tileStoreOp
.getMemRefType().getRank(),
380 tileStoreOp
.getMask(),
381 [&](Value tileSliceIndex
, ValueRange memrefIndices
, Value predicate
) {
382 rewriter
.replaceOpWithNewOp
<arm_sme::StoreTileSliceOp
>(
383 tileStoreOp
, tileStoreOp
.getValueToStore(), tileSliceIndex
,
384 predicate
, tileStoreOp
.getBase(), memrefIndices
,
385 tileStoreOp
.getLayout());
392 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet
&patterns
) {
393 patterns
.add
<TileLoadOpConversion
, TileLoadOpWithMaskAndPadNonZeroConversion
,
394 TileStoreOpConversion
>(patterns
.getContext());
399 struct ConvertArmSMEToSCFPass
400 : public impl::ConvertArmSMEToSCFBase
<ConvertArmSMEToSCFPass
> {
401 void runOnOperation() override
{
402 RewritePatternSet
patterns(&getContext());
403 ConversionTarget
target(getContext());
404 populateArmSMEToSCFConversionPatterns(patterns
);
405 target
.addLegalDialect
<arm_sme::ArmSMEDialect
, vector::VectorDialect
,
406 arith::ArithDialect
, scf::SCFDialect
>();
407 target
.addIllegalOp
<arm_sme::TileLoadOp
, arm_sme::TileStoreOp
>();
408 if (failed(applyPartialConversion(getOperation(), target
,
409 std::move(patterns
))))
416 std::unique_ptr
<Pass
> mlir::createConvertArmSMEToSCFPass() {
417 return std::make_unique
<ConvertArmSMEToSCFPass
>();