//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

using namespace mlir;

namespace {

/// Conversion pattern for vector.transfer_read.
///
/// ---
///
/// Example 1: op with identity permutation map to horizontal
///            arm_sme.tile_load:
///
///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d0, d1)
///
/// is converted to:
///
///   arm_sme.tile_load ...
///
/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
///            (in-flight transpose):
///
///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
///
/// is converted to:
///
///   arm_sme.tile_load ... layout<vertical>
struct TransferReadToArmSMELowering
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const final {
    // The permutation map must have two results.
    if (transferReadOp.getTransferRank() != 2)
      return rewriter.notifyMatchFailure(transferReadOp,
                                         "not a 2 result permutation map");

    auto vectorType = transferReadOp.getVectorType();
    if (!arm_sme::isValidSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(transferReadOp,
                                         "not a valid vector type for SME");

    if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
      return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");

    // Out-of-bounds dims are not supported.
    if (transferReadOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(transferReadOp,
                                         "not inbounds transfer read");

    AffineMap map = transferReadOp.getPermutationMap();
    if (!map.isPermutation())
      return rewriter.notifyMatchFailure(transferReadOp,
                                         "unsupported permutation map");

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !map.isIdentity();
    arm_sme::TileSliceLayout layout =
        transposed ? arm_sme::TileSliceLayout::Vertical
                   : arm_sme::TileSliceLayout::Horizontal;

    // Padding isn't optional for transfer_read, but it is only used for
    // out-of-bounds accesses (not supported here) and/or masking. The mask is
    // optional; if it's not present, don't pass the padding.
    auto mask = transferReadOp.getMask();
    auto padding = mask ? transferReadOp.getPadding() : nullptr;
    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
        transferReadOp, vectorType, transferReadOp.getSource(),
        transferReadOp.getIndices(), padding, mask, layout);

    return success();
  }
};

/// Conversion pattern for vector.transfer_write.
///
/// ---
///
/// Example 1: op with identity permutation map to horizontal
///            arm_sme.tile_store:
///
///   vector.transfer_write %vector, %source[%c0, %c0]
///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
///                                                   vector<[16]x[16]xi8>
///
/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
///            (in-flight transpose):
///
///   vector.transfer_write %vector, %source[%c0, %c0]
///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
///     : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const final {
    auto vType = writeOp.getVectorType();
    if (!arm_sme::isValidSMETileVectorType(vType))
      return failure();

    if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
      return failure();

    // Out-of-bounds dims are not supported.
    if (writeOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(writeOp,
                                         "not inbounds transfer write");

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         "unsupported permutation map");

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !map.isIdentity();
    arm_sme::TileSliceLayout layout =
        transposed ? arm_sme::TileSliceLayout::Vertical
                   : arm_sme::TileSliceLayout::Horizontal;

    rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
        writeOp, writeOp.getVector(), writeOp.getSource(),
        writeOp.getIndices(), writeOp.getMask(), layout);
    return success();
  }
};

/// Conversion pattern for vector.load.
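///
/// Illustrative example (the tile type below is only an example; any vector
/// type accepted by isValidSMETileVectorType is handled the same way):
///
///   %tile = vector.load %base[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %tile = arm_sme.tile_load %base[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>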
struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp load,
                                PatternRewriter &rewriter) const override {
    if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
      return failure();

    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
        load, load.getVectorType(), load.getBase(), load.getIndices());

    return success();
  }
};

/// Conversion pattern for vector.store.
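///
/// Illustrative example (mirrors the vector.load case above; the tile type is
/// only an example):
///
///   vector.store %tile, %base[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// is converted to:
///
///   arm_sme.tile_store %tile, %base[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>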
struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp store,
                                PatternRewriter &rewriter) const override {
    if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
      return failure();

    rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
        store, store.getValueToStore(), store.getBase(), store.getIndices());

    return success();
  }
};

/// Conversion pattern for vector.broadcast.
///
/// Example:
///
///   %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
///   {
///     %tile_update = arm_sme.insert_tile_slice
///        %broadcast_to_1d, %iter_tile[%tile_slice_index] :
///        vector<[4]xi32> into vector<[4]x[4]xi32>
///     scf.yield %tile_update : vector<[4]x[4]xi32>
///   }
///
/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
struct BroadcastOpToArmSMELowering
    : public OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const final {
    auto tileType = broadcastOp.getResultVectorType();
    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    auto loc = broadcastOp.getLoc();

    auto srcType = broadcastOp.getSourceType();
    auto srcVectorType = dyn_cast<VectorType>(srcType);

    Value broadcastOp1D;
    if (srcType.isIntOrFloat() ||
        (srcVectorType && (srcVectorType.getRank() == 0))) {
      // Broadcast scalar or 0-d vector to 1-d vector.
      VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
      broadcastOp1D = rewriter.create<vector::BroadcastOp>(
          loc, tileSliceType, broadcastOp.getSource());
    } else if (srcVectorType && (srcVectorType.getRank() == 1))
      // Value to broadcast is already a 1-d vector, nothing to do.
      broadcastOp1D = broadcastOp.getSource();
    else
      return failure();

    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);

    auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
                            Value currentTile) {
      // Create 'arm_sme.insert_tile_slice' to broadcast the value
      // to each tile slice.
      auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
      return nextTile.getResult();
    };

    // Create a loop over ZA tile slices.
    auto forOp =
        createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);

    rewriter.replaceOp(broadcastOp, forOp.getResult(0));

    return success();
  }
};

/// Conversion pattern for vector.splat.
///
/// Example:
///
///   %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
///   {
///     %tile_update = arm_sme.insert_tile_slice
///        %broadcast_to_1d, %iter_tile[%tile_slice_index] :
///        vector<[4]xi32> into vector<[4]x[4]xi32>
///     scf.yield %tile_update : vector<[4]x[4]xi32>
///   }
///
/// This is identical to vector.broadcast of a scalar.
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
                                PatternRewriter &rewriter) const final {
    auto tileType = splatOp.getResult().getType();
    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    auto loc = splatOp.getLoc();
    auto srcType = splatOp.getOperand().getType();

    assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
    // Avoid unused-variable warning when building without assertions.
    (void)srcType;

    // First, broadcast the scalar to a 1-d vector.
    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
    Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
        loc, tileSliceType, splatOp.getInput());

    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);

    auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
                            Value currentTile) {
      auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
      return nextTile.getResult();
    };

    // Next, create a loop over ZA tile slices and "move" the generated 1-d
    // vector to each slice.
    auto forOp =
        createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);

    rewriter.replaceOp(splatOp, forOp.getResult(0));

    return success();
  }
};

/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
///
/// Example:
///
///   %transposed_src = vector.transpose %src, [1, 0]
///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
///   arm_sme.tile_store %src, %alloca[%c0, %c0]
///     : memref<?x?xi32>, vector<[4]x[4]xi32>
///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
///     layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
///
/// NOTE: Transposing via memory is obviously expensive; the current intention
/// is to avoid the transpose if possible. This is therefore intended as a
/// fallback and to provide base support for Vector ops. If it turns out
/// transposes can't be avoided then this should be replaced with a more
/// optimal implementation, perhaps with tile <-> vector (MOVA) ops.
struct TransposeOpToArmSMELowering
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const final {
    auto tileType = transposeOp.getResultVectorType();
    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    // Bail unless this is a true 2-D matrix transpose.
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    if (permutation[0] != 1 || permutation[1] != 0)
      return failure();

    auto loc = transposeOp.getLoc();
    Value input = transposeOp.getVector();

    if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
        xferOp && xferOp->hasOneUse()) {
      // Fold transpose into transfer_read to enable in-flight transpose when
      // converting to arm_sme.tile_load.
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp->setAttr(xferOp.getPermutationMapAttrName(),
                        AffineMapAttr::get(AffineMap::getPermutationMap(
                            permutation, transposeOp.getContext())));
      });
      rewriter.replaceOp(transposeOp, xferOp);
      return success();
    }

    // Allocate buffer to store input tile to.
    Value vscale =
        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
    Value minTileSlices = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
    Value c0 =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value numTileSlices =
        rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
    auto bufferType =
        MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
                        tileType.getElementType());
    auto buffer = rewriter.create<memref::AllocaOp>(
        loc, bufferType, ValueRange{numTileSlices, numTileSlices});

    // Store input tile.
    auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
        loc, input, buffer, ValueRange{c0, c0});

    // Reload input tile vertically.
    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
        transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
        arm_sme::TileSliceLayout::Vertical);

    return success();
  }
};

/// Conversion pattern for vector.outerproduct.
///
/// If the vector.outerproduct is masked (and the mask is from a
/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
/// operands.
///
/// Example:
///
///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
///   %result = vector.mask %mask {
///                vector.outerproduct %vecA, %vecB
///                   : vector<[4]xf32>, vector<[4]xf32>
///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
///
/// is converted to:
///
///   %maskA = vector.create_mask %dimA : vector<[4]xi1>
///   %maskB = vector.create_mask %dimB : vector<[4]xi1>
///   %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
///               : vector<[4]xf32>, vector<[4]xf32>
///
/// Unmasked outerproducts can be directly replaced with the arm_sme op.
///
/// Example:
///
///   %result = vector.outerproduct %vecA, %vecB
///      : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   %result = arm_sme.outerproduct %vecA, %vecB
///      : vector<[4]xf32>, vector<[4]xf32>
///
struct VectorOuterProductToArmSMELowering
    : public OpRewritePattern<vector::OuterProductOp> {

  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
                                PatternRewriter &rewriter) const override {

    // We don't yet support lowering AXPY operations to SME. These could be
    // lowered by masking out all but the first element of the LHS.
    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         "AXPY operations not supported");

    if (!arm_sme::isValidSMETileVectorType(
            outerProductOp.getResultVectorType()))
      return rewriter.notifyMatchFailure(
          outerProductOp, "outer product does not fit into SME tile");

    auto kind = outerProductOp.getKind();
    if (kind != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(
          outerProductOp,
          "unsupported kind (lowering to SME only supports ADD at the moment)");

    Value lhsMask = {};
    Value rhsMask = {};
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      rewriter.setInsertionPoint(maskOp);
      rootOp = maskOp;
      auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
      if (failed(operandMasks))
        return failure();
      std::tie(lhsMask, rhsMask) = *operandMasks;
    }

    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());

    return success();
  }

  static FailureOr<std::pair<Value, Value>>
  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
    // Attempt to extract masks from vector.create_mask.
    // TODO: Add support for other mask sources.
    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    auto maskType = createMaskOp.getVectorType();
    Value lhsMaskDim = createMaskOp.getOperand(0);
    Value rhsMaskDim = createMaskOp.getOperand(1);

    VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
    Value lhsMask =
        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
    Value rhsMask =
        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);

    return std::make_pair(lhsMask, rhsMask);
  }
};

/// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
///
/// Example:
/// ```
/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
///            : vector<[4]xi32> from vector<[4]x[4]xi32>
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
/// ```
struct VectorExtractToArmSMELowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = extractOp.getSourceVectorType();
    if (!arm_sme::isValidSMETileVectorType(sourceType))
      return failure();

    auto loc = extractOp.getLoc();
    auto position = extractOp.getMixedPosition();

    Value sourceVector = extractOp.getVector();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (position.empty()) {
      rewriter.replaceOp(extractOp, sourceVector);
      return success();
    }

    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
    auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
        loc, sourceVector, sliceIndex);

    if (position.size() == 1) {
      // Single index case: Extracts a 1D slice.
      rewriter.replaceOp(extractOp, extractTileSlice);
      return success();
    }

    // Two indices case: Extracts a single element.
    assert(position.size() == 2);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
                                                   position[1]);

    return success();
  }
};

/// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
/// `arm_sme.extract_tile_slice`.
///
/// Example:
/// ```
/// %new_tile = vector.insert %el, %tile[%row, %col]
///                     : i32 into vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
///            : vector<[4]xi32> from vector<[4]x[4]xi32>
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
///               : vector<[4]xi32> into vector<[4]x[4]xi32>
/// ```
struct VectorInsertToArmSMELowering
    : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = insertOp.getResult().getType();

    if (!arm_sme::isValidSMETileVectorType(resultType))
      return failure();

    auto loc = insertOp.getLoc();
    auto position = insertOp.getMixedPosition();

    Value source = insertOp.getSource();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (position.empty()) {
      rewriter.replaceOp(insertOp, source);
      return success();
    }

    Value tileSlice = source;
    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
    if (position.size() == 2) {
      // Two indices case: Insert single element into tile.
      // We need to first extract the existing slice and update the element.
      tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
          loc, insertOp.getDest(), sliceIndex);
      tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
                                                    position[1]);
    }

    // Insert the slice into the destination tile.
    rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>(
        insertOp, tileSlice, insertOp.getDest(), sliceIndex);
    return success();
  }
};

/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
/// extracting them via `arm_sme.extract_tile_slice`, then printing with
/// a 1D `vector.print`.
///
///  BEFORE:
///  ```mlir
///  vector.print %tile : vector<[4]x[4]xf32>
///  ```
///  AFTER:
///  ```mlir
///  %c0 = arith.constant 0 : index
///  %c1 = arith.constant 1 : index
///  %c4 = arith.constant 4 : index
///  %vscale = vector.vscale
///  %svl_s = arith.muli %c4, %vscale : index
///  scf.for %i = %c0 to %svl_s step %c1 {
///    %tile_slice = arm_sme.extract_tile_slice %tile[%i]
///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
///    vector.print %tile_slice : vector<[4]xf32>
///  }
///  ```
struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
      return failure();

    auto loc = printOp.getLoc();

    // Create a loop over the rows of the tile.
    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
    auto minTileRows =
        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    {
      // Loop body.
      rewriter.setInsertionPointToStart(forOp.getBody());
      // Extract the current row from the tile.
      Value rowIndex = forOp.getInductionVar();
      auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
          loc, printOp.getSource(), rowIndex);
      // Print the row with a 1D vector.print.
      rewriter.create<vector::PrintOp>(loc, tileSlice,
                                       printOp.getPunctuation());
    }

    rewriter.eraseOp(printOp);
    return success();
  }
};

/// Folds an ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
///
///  BEFORE:
///  ```mlir
///  %slice = arm_sme.extract_tile_slice %tile[%index]
///             : vector<[4]xf32> from vector<[4]x[4]xf32>
///  vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
///    : vector<[4]xf32>, memref<?x?xf32>
///  ```
///  AFTER:
///  ```mlir
///  arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
///    : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
///  ```
struct FoldTransferWriteOfExtractTileSlice
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const final {
    if (!isa<MemRefType>(writeOp.getSource().getType()))
      return rewriter.notifyMatchFailure(writeOp, "destination not a memref");

    if (writeOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(writeOp,
                                         "not inbounds transfer write");

    auto extractTileSlice =
        writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
    if (!extractTileSlice)
      return rewriter.notifyMatchFailure(
          writeOp, "vector to store not from ExtractTileSliceOp");

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         "unsupported permutation map");

    Value mask = writeOp.getMask();
    if (!mask) {
      auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
      mask = rewriter.create<arith::ConstantOp>(
          writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
    }

    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
        writeOp, extractTileSlice.getTile(),
        extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
        writeOp.getIndices(), extractTileSlice.getLayout());
    return success();
  }
};

/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
/// SVE 2.1), so this is currently the most logical place for this lowering.
///
/// Example:
/// ```mlir
/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
/// %slice = vector.extract %mask[%index]
///            : vector<[8]xi1> from vector<[4]x[8]xi1>
/// ```
/// Becomes:
/// ```
/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
///            : vector<[8]xi1>, vector<[4]xi1>
/// ```
struct ExtractFromCreateMaskToPselLowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (extractOp.getNumIndices() != 1)
      return rewriter.notifyMatchFailure(extractOp, "not single extract index");

    auto resultType = extractOp.getResult().getType();
    auto resultVectorType = dyn_cast<VectorType>(resultType);
    if (!resultVectorType)
      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");

    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");

    auto maskType = createMaskOp.getVectorType();
    if (maskType.getRank() != 2 || !maskType.allDimsScalable())
      return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");

    auto isSVEPredicateSize = [](int64_t size) {
      return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
    };

    auto rowsBaseSize = maskType.getDimSize(0);
    auto colsBaseSize = maskType.getDimSize(1);
    if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
      return rewriter.notifyMatchFailure(
          createMaskOp, "mask dimensions not SVE predicate-sized");

    auto loc = extractOp.getLoc();
    VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
    VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);

    // Create the two 1-D masks at the location of the 2-D create_mask (which
    // is usually outside a loop). This prevents the need for later hoisting.
    rewriter.setInsertionPoint(createMaskOp);
    auto rowMask = rewriter.create<vector::CreateMaskOp>(
        loc, rowMaskType, createMaskOp.getOperand(0));
    auto colMask = rewriter.create<vector::CreateMaskOp>(
        loc, colMaskType, createMaskOp.getOperand(1));

    rewriter.setInsertionPoint(extractOp);
    auto position =
        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
                                                 position[0]);
    return success();
  }
};

} // namespace
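
/// Populates \p patterns with the Vector-to-ArmSME conversion patterns defined
/// above. A minimal usage sketch (assuming a pass that drives these patterns
/// with the greedy rewrite driver; the surrounding pass boilerplate is not
/// shown here):
///
///   RewritePatternSet patterns(&getContext());
///   populateVectorToArmSMEPatterns(patterns, getContext());
///   if (failed(applyPatternsAndFoldGreedily(getOperation(),
///                                           std::move(patterns))))
///     signalPassFailure();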
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                          MLIRContext &ctx) {
  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
               VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
               ExtractFromCreateMaskToPselLowering>(&ctx);
}