1 //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
11 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
12 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
13 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "llvm/Support/Casting.h"
22 /// Conversion pattern for vector.transfer_read.
26 /// Example 1: op with identity permutation map to horizontal
27 /// arm_sme.tile_load:
29 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
33 /// arm_sme.tile_load ...
37 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
38 /// (in-flight transpose):
40 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
44 /// arm_sme.tile_load ... layout<vertical>
45 struct TransferReadToArmSMELowering
46 : public OpRewritePattern
<vector::TransferReadOp
> {
47 using OpRewritePattern
<vector::TransferReadOp
>::OpRewritePattern
;
49 LogicalResult
matchAndRewrite(vector::TransferReadOp transferReadOp
,
50 PatternRewriter
&rewriter
) const final
{
51 // The permutation map must have two results.
52 if (transferReadOp
.getTransferRank() != 2)
53 return rewriter
.notifyMatchFailure(transferReadOp
,
54 "not a 2 result permutation map");
56 auto vectorType
= transferReadOp
.getVectorType();
57 if (!arm_sme::isValidSMETileVectorType(vectorType
))
58 return rewriter
.notifyMatchFailure(transferReadOp
,
59 "not a valid vector type for SME");
61 if (!llvm::isa
<MemRefType
>(transferReadOp
.getSource().getType()))
62 return rewriter
.notifyMatchFailure(transferReadOp
, "not a memref source");
64 // Out-of-bounds dims are not supported.
65 if (transferReadOp
.hasOutOfBoundsDim())
66 return rewriter
.notifyMatchFailure(transferReadOp
,
67 "not inbounds transfer read");
69 AffineMap map
= transferReadOp
.getPermutationMap();
70 if (!map
.isPermutation())
71 return rewriter
.notifyMatchFailure(transferReadOp
,
72 "unsupported permutation map");
74 // Note: For 2D vector types the only non-identity permutation is a simple
76 bool transposed
= !map
.isIdentity();
77 arm_sme::TileSliceLayout layout
=
78 transposed
? arm_sme::TileSliceLayout::Vertical
79 : arm_sme::TileSliceLayout::Horizontal
;
81 // Padding isn't optional for transfer_read, but is only used in the case
82 // of out-of-bounds accesses (not supported here) and/or masking. Mask is
83 // optional, if it's not present don't pass padding.
84 auto mask
= transferReadOp
.getMask();
85 auto padding
= mask
? transferReadOp
.getPadding() : nullptr;
86 rewriter
.replaceOpWithNewOp
<arm_sme::TileLoadOp
>(
87 transferReadOp
, vectorType
, transferReadOp
.getSource(),
88 transferReadOp
.getIndices(), padding
, mask
, layout
);
94 /// Conversion pattern for vector.transfer_write.
98 /// Example 1: op with identity permutation map to horizontal
99 /// arm_sme.tile_store:
101 /// vector.transfer_write %vector, %source[%c0, %c0]
102 /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
106 /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
107 /// vector<[16]x[16]xi8>
110 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
111 /// (in-flight transpose):
113 /// vector.transfer_write %vector, %source[%c0, %c0]
114 /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
115 /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
119 /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
120 /// : memref<?x?xi8>, vector<[16]x[16]xi8>
121 struct TransferWriteToArmSMELowering
122 : public OpRewritePattern
<vector::TransferWriteOp
> {
123 using OpRewritePattern
<vector::TransferWriteOp
>::OpRewritePattern
;
125 LogicalResult
matchAndRewrite(vector::TransferWriteOp writeOp
,
126 PatternRewriter
&rewriter
) const final
{
127 auto vType
= writeOp
.getVectorType();
128 if (!arm_sme::isValidSMETileVectorType(vType
))
131 if (!llvm::isa
<MemRefType
>(writeOp
.getSource().getType()))
134 // Out-of-bounds dims are not supported.
135 if (writeOp
.hasOutOfBoundsDim())
136 return rewriter
.notifyMatchFailure(writeOp
,
137 "not inbounds transfer write");
139 AffineMap map
= writeOp
.getPermutationMap();
140 if (!map
.isPermutation())
141 return rewriter
.notifyMatchFailure(writeOp
,
142 "unsupported permutation map");
144 // Note: For 2D vector types the only non-identity permutation is a simple
146 bool transposed
= !map
.isIdentity();
147 arm_sme::TileSliceLayout layout
=
148 transposed
? arm_sme::TileSliceLayout::Vertical
149 : arm_sme::TileSliceLayout::Horizontal
;
151 rewriter
.replaceOpWithNewOp
<arm_sme::TileStoreOp
>(
152 writeOp
, writeOp
.getVector(), writeOp
.getSource(), writeOp
.getIndices(),
153 writeOp
.getMask(), layout
);
158 /// Conversion pattern for vector.load.
159 struct VectorLoadToArmSMELowering
: public OpRewritePattern
<vector::LoadOp
> {
160 using OpRewritePattern
<vector::LoadOp
>::OpRewritePattern
;
162 LogicalResult
matchAndRewrite(vector::LoadOp load
,
163 PatternRewriter
&rewriter
) const override
{
164 if (!arm_sme::isValidSMETileVectorType(load
.getVectorType()))
167 rewriter
.replaceOpWithNewOp
<arm_sme::TileLoadOp
>(
168 load
, load
.getVectorType(), load
.getBase(), load
.getIndices());
174 /// Conversion pattern for vector.store.
175 struct VectorStoreToArmSMELowering
: public OpRewritePattern
<vector::StoreOp
> {
176 using OpRewritePattern
<vector::StoreOp
>::OpRewritePattern
;
178 LogicalResult
matchAndRewrite(vector::StoreOp store
,
179 PatternRewriter
&rewriter
) const override
{
180 if (!arm_sme::isValidSMETileVectorType(store
.getVectorType()))
183 rewriter
.replaceOpWithNewOp
<arm_sme::TileStoreOp
>(
184 store
, store
.getValueToStore(), store
.getBase(), store
.getIndices());
190 /// Conversion pattern for vector.broadcast.
194 /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
198 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
199 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
200 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
202 /// %tile_update = arm_sme.insert_tile_slice
203 /// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
204 /// vector<[4]xi32> into vector<[4]x[4]xi32>
205 /// scf.yield %tile_update : vector<[4]x[4]xi32>
208 /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
209 struct BroadcastOpToArmSMELowering
210 : public OpRewritePattern
<vector::BroadcastOp
> {
211 using OpRewritePattern
<vector::BroadcastOp
>::OpRewritePattern
;
213 LogicalResult
matchAndRewrite(vector::BroadcastOp broadcastOp
,
214 PatternRewriter
&rewriter
) const final
{
215 auto tileType
= broadcastOp
.getResultVectorType();
216 if (!tileType
|| !arm_sme::isValidSMETileVectorType(tileType
))
219 auto loc
= broadcastOp
.getLoc();
221 auto srcType
= broadcastOp
.getSourceType();
222 auto srcVectorType
= dyn_cast
<VectorType
>(srcType
);
225 if (srcType
.isIntOrFloat() ||
226 (srcVectorType
&& (srcVectorType
.getRank() == 0))) {
227 // Broadcast scalar or 0-d vector to 1-d vector.
228 VectorType tileSliceType
= VectorType::Builder(tileType
).dropDim(0);
229 broadcastOp1D
= rewriter
.create
<vector::BroadcastOp
>(
230 loc
, tileSliceType
, broadcastOp
.getSource());
231 } else if (srcVectorType
&& (srcVectorType
.getRank() == 1))
232 // Value to broadcast is already a 1-d vector, nothing to do.
233 broadcastOp1D
= broadcastOp
.getSource();
237 auto initTile
= rewriter
.create
<arm_sme::GetTileOp
>(loc
, tileType
);
239 auto makeLoopBody
= [&](OpBuilder
&b
, Location loc
, Value tileSliceIndex
,
241 // Create 'arm_sme.insert_tile_slice' to broadcast the value
242 // to each tile slice.
243 auto nextTile
= b
.create
<arm_sme::InsertTileSliceOp
>(
244 loc
, tileType
, broadcastOp1D
, currentTile
, tileSliceIndex
);
245 return nextTile
.getResult();
248 // Create a loop over ZA tile slices.
250 createLoopOverTileSlices(rewriter
, loc
, initTile
, makeLoopBody
);
252 rewriter
.replaceOp(broadcastOp
, forOp
.getResult(0));
258 /// Conversion pattern for vector.splat.
262 /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
266 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
270 /// %tile_update = arm_sme.insert_tile_slice
271 /// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272 /// vector<[4]xi32> into vector<[4]x[4]xi32>
273 /// scf.yield %tile_update : vector<[4]x[4]xi32>
276 /// This is identical to vector.broadcast of a scalar.
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
                                PatternRewriter &rewriter) const final {
    auto tileType = splatOp.getResult().getType();
    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    auto loc = splatOp.getLoc();
    auto srcType = splatOp.getOperand().getType();

    assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
    // Avoid unused-variable warning when building without assertions.
    (void)srcType;

    // First, broadcast the scalar to a 1-d vector.
    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
    Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
        loc, tileSliceType, splatOp.getInput());

    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);

    // Each iteration inserts the 1-d broadcast vector into one tile slice.
    auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
                            Value currentTile) {
      auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
      return nextTile.getResult();
    };

    // Next, create a loop over ZA tile slices and "move" the generated 1-d
    // vector to each slice.
    auto forOp =
        createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);

    rewriter.replaceOp(splatOp, forOp.getResult(0));

    return success();
  }
};
318 /// Conversion pattern for vector.transpose.
320 /// Stores the input tile to memory and reloads vertically.
324 /// %transposed_src = vector.transpose %src, [1, 0]
325 /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
329 /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
330 /// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
331 /// : memref<?x?xi32>, vector<[4]x[4]xi32>
332 /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
333 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
335 /// NOTE: Tranposing via memory is obviously expensive, the current intention
336 /// is to avoid the transpose if possible, this is therefore intended as a
337 /// fallback and to provide base support for Vector ops. If it turns out
338 /// transposes can't be avoided then this should be replaced with a more optimal
339 /// implementation, perhaps with tile <-> vector (MOVA) ops.
struct TransposeOpToArmSMELowering
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const final {
    auto tileType = transposeOp.getResultVectorType();
    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    // Bail unless this is a true 2-D matrix transpose.
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    if (permutation[0] != 1 || permutation[1] != 0)
      return failure();

    auto loc = transposeOp.getLoc();
    Value input = transposeOp.getVector();

    if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
        xferOp && xferOp->hasOneUse()) {
      // Fold transpose into transfer_read to enable in-flight transpose when
      // converting to arm_sme.tile_load.
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp->setAttr(xferOp.getPermutationMapAttrName(),
                        AffineMapAttr::get(AffineMap::getPermutationMap(
                            permutation, transposeOp.getContext())));
      });
      rewriter.replaceOp(transposeOp, xferOp);
      return success();
    }

    // Allocate buffer to store input tile to. Both dimensions are dynamic:
    // (dim-size * vscale) rows/cols.
    Value vscale =
        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
    Value minTileSlices = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
    Value c0 =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value numTileSlices =
        rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
    auto bufferType =
        MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
                        tileType.getElementType());
    auto buffer = rewriter.create<memref::AllocaOp>(
        loc, bufferType, ValueRange{numTileSlices, numTileSlices});

    // Store input tile (horizontal layout).
    auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
        loc, input, buffer, ValueRange{c0, c0});

    // Reload input tile vertically.
    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
        transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
        arm_sme::TileSliceLayout::Vertical);

    return success();
  }
};
399 /// Conversion pattern for vector.outerproduct.
401 /// If the vector.outerproduct is masked (and the mask is from a
402 /// vector.create_mask), then the mask is decomposed into two 1-D masks for the
407 /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
408 /// %result = vector.mask %mask {
409 /// vector.outerproduct %vecA, %vecB
410 /// : vector<[4]xf32>, vector<[4]xf32>
411 /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
415 /// %maskA = vector.create_mask %dimA : vector<[4]xi1>
416 /// %maskB = vector.create_mask %dimB : vector<[4]xi1>
417 /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
418 /// : vector<[4]xf32>, vector<[4]xf32>
420 /// Unmasked outerproducts can be directly replaced with the arm_sme op.
424 /// %result = vector.outerproduct %vecA, %vecB
425 /// : vector<[4]xf32>, vector<[4]xf32>
429 /// %result = arm_sme.outerproduct %vecA, %vecB
430 /// : vector<[4]xf32>, vector<[4]xf32>
struct VectorOuterProductToArmSMELowering
    : public OpRewritePattern<vector::OuterProductOp> {

  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
                                PatternRewriter &rewriter) const override {

    // We don't yet support lowering AXPY operations to SME. These could be
    // lowered by masking out all but the first element of the LHS.
    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         "AXPY operations not supported");

    if (!arm_sme::isValidSMETileVectorType(
            outerProductOp.getResultVectorType()))
      return rewriter.notifyMatchFailure(
          outerProductOp, "outer product does not fit into SME tile");

    auto kind = outerProductOp.getKind();
    if (kind != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(
          outerProductOp,
          "unsupported kind (lowering to SME only supports ADD at the moment)");

    Value lhsMask = {};
    Value rhsMask = {};
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      // Decompose the 2-D result mask into two 1-D operand masks and replace
      // the enclosing vector.mask op (the root) rather than the outer product.
      auto maskOp = outerProductOp.getMaskingOp();
      rewriter.setInsertionPoint(maskOp);
      rootOp = maskOp;
      auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
      if (failed(operandMasks))
        return failure();
      std::tie(lhsMask, rhsMask) = *operandMasks;
    }

    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());

    return success();
  }

  /// Splits a 2-D mask from a vector.create_mask into two 1-D operand masks
  /// (one per outer product operand). Fails for any other mask source.
  static FailureOr<std::pair<Value, Value>>
  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
    // Attempt to extract masks from vector.create_mask.
    // TODO: Add support for other mask sources.
    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    auto maskType = createMaskOp.getVectorType();
    Value lhsMaskDim = createMaskOp.getOperand(0);
    Value rhsMaskDim = createMaskOp.getOperand(1);

    VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
    Value lhsMask =
        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
    Value rhsMask =
        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);

    return std::make_pair(lhsMask, rhsMask);
  }
};
500 /// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
504 /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
508 /// %slice = arm_sme.extract_tile_slice %tile[%row]
509 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
510 /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
struct VectorExtractToArmSMELowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = extractOp.getSourceVectorType();
    if (!arm_sme::isValidSMETileVectorType(sourceType))
      return failure();

    auto loc = extractOp.getLoc();
    auto position = extractOp.getMixedPosition();

    Value sourceVector = extractOp.getVector();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (position.empty()) {
      rewriter.replaceOp(extractOp, sourceVector);
      return success();
    }

    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
    auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
        loc, sourceVector, sliceIndex);

    if (position.size() == 1) {
      // Single index case: Extracts a 1D slice.
      rewriter.replaceOp(extractOp, extractTileSlice);
      return success();
    }

    // Two indices case: Extracts a single element.
    assert(position.size() == 2);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
                                                   position[1]);

    return success();
  }
};
552 /// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
553 /// `arm_sme.extract_tile_slice`.
557 /// %new_tile = vector.insert %el, %tile[%row, %col]
558 /// : i32 into vector<[4]x[4]xi32>
562 /// %slice = arm_sme.extract_tile_slice %tile[%row]
563 /// : vector<[4]xi32> from vector<[4]x[4]xi32>
564 /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
565 /// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
566 /// : vector<[4]xi32> into vector<[4]x[4]xi32>
568 struct VectorInsertToArmSMELowering
569 : public OpRewritePattern
<vector::InsertOp
> {
570 using OpRewritePattern
<vector::InsertOp
>::OpRewritePattern
;
572 LogicalResult
matchAndRewrite(vector::InsertOp insertOp
,
573 PatternRewriter
&rewriter
) const override
{
574 VectorType resultType
= insertOp
.getResult().getType();
576 if (!arm_sme::isValidSMETileVectorType(resultType
))
579 auto loc
= insertOp
.getLoc();
580 auto position
= insertOp
.getMixedPosition();
582 Value source
= insertOp
.getSource();
584 // Overwrite entire vector with value. Should be handled by folder, but
586 if (position
.empty()) {
587 rewriter
.replaceOp(insertOp
, source
);
591 Value tileSlice
= source
;
592 Value sliceIndex
= vector::getAsValues(rewriter
, loc
, position
[0]).front();
593 if (position
.size() == 2) {
594 // Two indices case: Insert single element into tile.
595 // We need to first extract the existing slice and update the element.
596 tileSlice
= rewriter
.create
<arm_sme::ExtractTileSliceOp
>(
597 loc
, insertOp
.getDest(), sliceIndex
);
598 tileSlice
= rewriter
.create
<vector::InsertOp
>(loc
, source
, tileSlice
,
602 // Insert the slice into the destination tile.
603 rewriter
.replaceOpWithNewOp
<arm_sme::InsertTileSliceOp
>(
604 insertOp
, tileSlice
, insertOp
.getDest(), sliceIndex
);
609 /// Lowers `vector.print` of a tile into a loop over the rows of the tile,
610 /// extracting them via `arm_sme.extract_tile_slice`, then printing with
611 /// a 1D `vector.print`.
615 /// vector.print %tile : vector<[4]x[4]xf32>
619 /// %c0 = arith.constant 0 : index
620 /// %c1 = arith.constant 1 : index
621 /// %c4 = arith.constant 4 : index
622 /// %vscale = vector.vscale
623 /// %svl_s = arith.muli %c4, %vscale : index
624 /// scf.for %i = %c0 to %svl_s step %c1 {
625 /// %tile_slice = arm_sme.extract_tile_slice %tile[%i]
626 /// : vector<[4]xf32> from vector<[4]x[4]xf32>
627 /// vector.print %tile_slice : vector<[4]xf32>
630 struct VectorPrintToArmSMELowering
: public OpRewritePattern
<vector::PrintOp
> {
631 using OpRewritePattern
<vector::PrintOp
>::OpRewritePattern
;
633 LogicalResult
matchAndRewrite(vector::PrintOp printOp
,
634 PatternRewriter
&rewriter
) const override
{
635 if (!printOp
.getSource())
638 VectorType vectorType
= dyn_cast
<VectorType
>(printOp
.getPrintType());
639 if (!vectorType
|| !arm_sme::isValidSMETileVectorType(vectorType
))
642 auto loc
= printOp
.getLoc();
644 // Create a loop over the rows of the tile.
645 auto vscale
= rewriter
.create
<vector::VectorScaleOp
>(loc
);
647 rewriter
.create
<arith::ConstantIndexOp
>(loc
, vectorType
.getDimSize(0));
648 auto lowerBound
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 0);
649 auto upperBound
= rewriter
.create
<arith::MulIOp
>(loc
, minTileRows
, vscale
);
650 auto step
= rewriter
.create
<arith::ConstantIndexOp
>(loc
, 1);
651 auto forOp
= rewriter
.create
<scf::ForOp
>(loc
, lowerBound
, upperBound
, step
);
654 rewriter
.setInsertionPointToStart(forOp
.getBody());
655 // Extract the current row from the tile.
656 Value rowIndex
= forOp
.getInductionVar();
657 auto tileSlice
= rewriter
.create
<arm_sme::ExtractTileSliceOp
>(
658 loc
, printOp
.getSource(), rowIndex
);
659 // Print the row with a 1D vector.print.
660 rewriter
.create
<vector::PrintOp
>(loc
, tileSlice
,
661 printOp
.getPunctuation());
664 rewriter
.eraseOp(printOp
);
669 /// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
673 /// %slice = arm_sme.extract_tile_slice %tile[%index]
674 /// : vector<[4]xf32> from vector<[4]x[4]xf32>
675 /// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
676 /// : vector<[4]xf32>, memref<?x?xf32>
680 /// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
681 /// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
683 struct FoldTransferWriteOfExtractTileSlice
684 : public OpRewritePattern
<vector::TransferWriteOp
> {
685 using OpRewritePattern
<vector::TransferWriteOp
>::OpRewritePattern
;
687 LogicalResult
matchAndRewrite(vector::TransferWriteOp writeOp
,
688 PatternRewriter
&rewriter
) const final
{
689 if (!isa
<MemRefType
>(writeOp
.getSource().getType()))
690 return rewriter
.notifyMatchFailure(writeOp
, "destination not a memref");
692 if (writeOp
.hasOutOfBoundsDim())
693 return rewriter
.notifyMatchFailure(writeOp
,
694 "not inbounds transfer write");
696 auto extractTileSlice
=
697 writeOp
.getVector().getDefiningOp
<arm_sme::ExtractTileSliceOp
>();
698 if (!extractTileSlice
)
699 return rewriter
.notifyMatchFailure(
700 writeOp
, "vector to store not from ExtractTileSliceOp");
702 AffineMap map
= writeOp
.getPermutationMap();
703 if (!map
.isMinorIdentity())
704 return rewriter
.notifyMatchFailure(writeOp
,
705 "unsupported permutation map");
707 Value mask
= writeOp
.getMask();
709 auto maskType
= writeOp
.getVectorType().clone(rewriter
.getI1Type());
710 mask
= rewriter
.create
<arith::ConstantOp
>(
711 writeOp
.getLoc(), maskType
, DenseElementsAttr::get(maskType
, true));
714 rewriter
.replaceOpWithNewOp
<arm_sme::StoreTileSliceOp
>(
715 writeOp
, extractTileSlice
.getTile(),
716 extractTileSlice
.getTileSliceIndex(), mask
, writeOp
.getSource(),
717 writeOp
.getIndices(), extractTileSlice
.getLayout());
722 /// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
723 /// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
724 /// SVE 2.1), so this is currently the most logical place for this lowering.
728 /// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
729 /// %slice = vector.extract %mask[%index]
730 /// : vector<[8]xi1> from vector<[4]x[8]xi1>
734 /// %mask_rows = vector.create_mask %a : vector<[4]xi1>
735 /// %mask_cols = vector.create_mask %b : vector<[8]xi1>
736 /// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
737 /// : vector<[8]xi1>, vector<[4]xi1>
struct ExtractFromCreateMaskToPselLowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (extractOp.getNumIndices() != 1)
      return rewriter.notifyMatchFailure(extractOp, "not single extract index");

    auto resultType = extractOp.getResult().getType();
    auto resultVectorType = dyn_cast<VectorType>(resultType);
    if (!resultVectorType)
      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");

    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");

    auto maskType = createMaskOp.getVectorType();
    if (maskType.getRank() != 2 || !maskType.allDimsScalable())
      return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");

    // SVE predicates are limited to powers of two in [1, 16] per dimension.
    auto isSVEPredicateSize = [](int64_t size) {
      return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
    };

    auto rowsBaseSize = maskType.getDimSize(0);
    auto colsBaseSize = maskType.getDimSize(1);
    if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
      return rewriter.notifyMatchFailure(
          createMaskOp, "mask dimensions not SVE predicate-sized");

    auto loc = extractOp.getLoc();
    VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
    VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);

    // Create the two 1-D masks at the location of the 2-D create_mask (which is
    // usually outside a loop). This prevents the need for later hoisting.
    rewriter.setInsertionPoint(createMaskOp);
    auto rowMask = rewriter.create<vector::CreateMaskOp>(
        loc, rowMaskType, createMaskOp.getOperand(0));
    auto colMask = rewriter.create<vector::CreateMaskOp>(
        loc, colMaskType, createMaskOp.getOperand(1));

    rewriter.setInsertionPoint(extractOp);
    auto smeTileSliceIndex =
        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
                                                 smeTileSliceIndex[0]);
    return success();
  }
};
795 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet
&patterns
,
797 patterns
.add
<BroadcastOpToArmSMELowering
, SplatOpToArmSMELowering
,
798 TransferReadToArmSMELowering
, TransferWriteToArmSMELowering
,
799 TransposeOpToArmSMELowering
, VectorLoadToArmSMELowering
,
800 VectorStoreToArmSMELowering
, VectorOuterProductToArmSMELowering
,
801 VectorExtractToArmSMELowering
, VectorInsertToArmSMELowering
,
802 VectorPrintToArmSMELowering
, FoldTransferWriteOfExtractTileSlice
,
803 ExtractFromCreateMaskToPselLowering
>(&ctx
);