Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / ArmSMEToLLVM / ArmSMEToLLVM.cpp
blob40a3489f7a4d7bbe821be5ebb9d8c19ad286e84a
1 //===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to LLVM intrinsics.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
15 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16 #include "mlir/Conversion/LLVMCommon/Pattern.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
19 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
20 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
21 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "llvm/ADT/ScopeExit.h"
30 namespace mlir {
31 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
32 #include "mlir/Conversion/Passes.h.inc"
33 } // namespace mlir
35 using namespace mlir;
37 namespace {
39 static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
41 /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
42 static Operation *createLoadTileSliceIntrinsic(
43 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
44 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
45 IntegerAttr tileId, Value tileSliceI32) {
46 if (layout == arm_sme::TileSliceLayout::Horizontal) {
47 switch (type) {
48 case arm_sme::ArmSMETileType::ZAB:
49 return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
50 loc, maskOp, ptr, tileId, tileSliceI32);
51 case arm_sme::ArmSMETileType::ZAH:
52 return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
53 loc, maskOp, ptr, tileId, tileSliceI32);
54 case arm_sme::ArmSMETileType::ZAS:
55 return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
56 loc, maskOp, ptr, tileId, tileSliceI32);
57 case arm_sme::ArmSMETileType::ZAD:
58 return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
59 loc, maskOp, ptr, tileId, tileSliceI32);
60 case arm_sme::ArmSMETileType::ZAQ:
61 return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
62 loc, maskOp, ptr, tileId, tileSliceI32);
64 } else {
65 switch (type) {
66 case arm_sme::ArmSMETileType::ZAB:
67 return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
68 loc, maskOp, ptr, tileId, tileSliceI32);
69 case arm_sme::ArmSMETileType::ZAH:
70 return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
71 loc, maskOp, ptr, tileId, tileSliceI32);
72 case arm_sme::ArmSMETileType::ZAS:
73 return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
74 loc, maskOp, ptr, tileId, tileSliceI32);
75 case arm_sme::ArmSMETileType::ZAD:
76 return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
77 loc, maskOp, ptr, tileId, tileSliceI32);
78 case arm_sme::ArmSMETileType::ZAQ:
79 return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
80 loc, maskOp, ptr, tileId, tileSliceI32);
81 break;
84 llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
87 /// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
88 static Operation *createStoreTileSliceIntrinsic(
89 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
90 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
91 IntegerAttr tileId, Value tileSliceI32) {
92 if (layout == arm_sme::TileSliceLayout::Horizontal) {
93 switch (type) {
94 case arm_sme::ArmSMETileType::ZAB:
95 return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
96 loc, maskOp, ptr, tileId, tileSliceI32);
97 case arm_sme::ArmSMETileType::ZAH:
98 return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
99 loc, maskOp, ptr, tileId, tileSliceI32);
100 case arm_sme::ArmSMETileType::ZAS:
101 return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
102 loc, maskOp, ptr, tileId, tileSliceI32);
103 case arm_sme::ArmSMETileType::ZAD:
104 return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
105 loc, maskOp, ptr, tileId, tileSliceI32);
106 case arm_sme::ArmSMETileType::ZAQ:
107 return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
108 loc, maskOp, ptr, tileId, tileSliceI32);
110 } else {
111 switch (type) {
112 case arm_sme::ArmSMETileType::ZAB:
113 return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
114 loc, maskOp, ptr, tileId, tileSliceI32);
115 case arm_sme::ArmSMETileType::ZAH:
116 return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
117 loc, maskOp, ptr, tileId, tileSliceI32);
118 case arm_sme::ArmSMETileType::ZAS:
119 return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
120 loc, maskOp, ptr, tileId, tileSliceI32);
121 case arm_sme::ArmSMETileType::ZAD:
122 return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
123 loc, maskOp, ptr, tileId, tileSliceI32);
124 case arm_sme::ArmSMETileType::ZAQ:
125 return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
126 loc, maskOp, ptr, tileId, tileSliceI32);
129 llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
132 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
133 auto tileId = op.getTileId();
134 if (!tileId)
135 op.emitOpError(
136 "expected tile ID to be allocated before conversion to LLVM");
137 return tileId;
140 /// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
141 /// placed in the first block of the function.
142 static memref::AllocaOp
143 createAllocaForTile(RewriterBase &rewriter, Location loc,
144 FunctionOpInterface func,
145 arm_sme::ArmSMETileOpInterface tileOp) {
146 RewriterBase::InsertionGuard g(rewriter);
147 // Move to the first operation in the function.
148 rewriter.setInsertionPointToStart(&func.getBlocks().front());
149 // Create an alloca matching the tile size of the `tileOp`.
150 auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
151 auto tileElementType = tileOp.getTileType().getElementType();
152 auto memrefType = MemRefType::get(
153 {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
154 unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
155 auto minElementsOp =
156 rewriter.create<arith::ConstantIndexOp>(loc, minElements);
157 auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
158 auto alloca = rewriter.create<memref::AllocaOp>(
159 loc, memrefType, ValueRange{vectorLen, vectorLen});
160 return alloca;
163 /// Finds or creates an alloca for a spill of a tile.
164 static memref::AllocaOp getOrCreateAllocaForTile(
165 RewriterBase &rewriter, Location loc, FunctionOpInterface func,
166 arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
167 // Find an alloca at the top of the function tagged with a
168 // 'arm_sme.in_memory_tile_id' that matches `tileId`.
169 for (auto &op : func.getBlocks().front()) {
170 auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
171 if (!alloca)
172 continue;
173 auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
174 alloca->getDiscardableAttr(kInMemoryTileIdAttr));
175 if (!inMemoryTileId)
176 continue;
177 if (inMemoryTileId.getInt() == tileId)
178 return alloca;
180 // Otherwise, create a new alloca:
181 auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
182 alloca->setDiscardableAttr(kInMemoryTileIdAttr,
183 rewriter.getI32IntegerAttr(tileId));
184 return alloca;
187 /// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
188 /// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
189 /// the op to tile 0, then emitting a full tile swap between ZA and memory
190 /// before + after the tile op.
192 /// Example:
194 /// // Note: <IN MEMORY TILE> = tile ID >= 16.
195 /// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
197 /// is converted to:
198 /// // At function entry:
199 /// %spill = memref.alloca ... : memref<?x?xty>
201 /// // Around op:
202 /// scf.for %slice_idx {
203 /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
204 /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
205 /// vector.store %slice_to_save, %spill[%slice_idx, %c0]
206 /// }
207 /// arm_sme.tile_op { tile_id = 0 }
208 /// scf.for %slice_idx {
209 /// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
210 /// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
211 /// vector.store %slice_to_save, %spill[%slice_idx, %c0]
212 /// }
214 /// Note that these spills/fills are not inserted earlier as concept of a
215 /// register, and the need to swap the contents, can't really be represented
216 /// correctly at a high level in MLIR.
218 /// TODO: Reduce the spills/reloads to single slices where possible (and omit
219 /// redundant reloads). This could be done via a method on the
220 /// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
222 /// `tileOp.getZaUsage()` could return:
224 /// struct ArmSMEOpZAUsage {
225 /// enum class Kind {
226 /// TileRead, // Omit store after tile operation.
227 /// TileWrite, // Omit load before tile operation.
228 /// TileReadWrite, // Needs both tile load and store.
229 /// SliceRead, // Spill single slice and omit store after operation.
230 /// SliceWrite, // Spill single slice and omit load before operation.
231 /// SliceReadWrite // Spill single slice.
232 /// };
233 /// Value sliceIndex {};
234 /// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
235 /// };
237 struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
239 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
240 const LLVMTypeConverter &typeConverter,
241 PatternBenefit benefit)
242 : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
243 typeConverter, benefit) {}
245 LogicalResult
246 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
247 ConversionPatternRewriter &rewriter) const override {
248 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
249 // Tile has a real (hardware) tile. No spills/reloads required.
250 if (!tileOp.isInMemoryTile())
251 return failure();
253 tileOp->emitWarning(
254 "failed to allocate SME virtual tile to operation, tile value will go "
255 "through memory, expect degraded performance");
257 // Step 1. Create an alloca for the tile at the top of the function (if one
258 // does not already exist).
259 auto loc = tileOp.getLoc();
260 auto func = tileOp->getParentOfType<FunctionOpInterface>();
261 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
262 tileOp.getTileId().getInt());
264 // Step 2. Assign the op a real tile ID.
265 // For simplicity, we always use tile 0 (which always exists).
266 auto zeroTileId = rewriter.getI32IntegerAttr(0);
267 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
269 VectorType tileVectorType = tileOp.getTileType();
270 auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
271 auto swapInMemoryTileWithSMETileZero = [&] {
272 emitFullTileSwap(rewriter, loc, tileAlloca,
273 *arm_sme::getSMETileType(tileVectorType), sliceType,
274 zeroTileId);
277 // Step 3. Emit tile swaps before and after the op.
278 // TODO: Reduce the amount spilled to the amount of data the `tileOp`
279 // touches (i.e. a single tile slice).
281 rewriter.setInsertionPoint(op);
282 // Swap the contents of ZA and the in-memory tile before the op.
283 swapInMemoryTileWithSMETileZero();
284 rewriter.setInsertionPointAfter(op);
285 // Swap the tile back out to memory again after the op.
286 swapInMemoryTileWithSMETileZero();
289 return success();
292 /// Extracts a pointer to a slice of an in-memory tile.
293 Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
294 Value tileMemory, Value sliceIndex) const {
295 auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
296 auto descriptor =
297 rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
298 auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
299 auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
300 loc, rewriter.getI64Type(), sliceIndex);
301 return getStridedElementPtr(
302 loc, llvm::cast<MemRefType>(tileMemory.getType()),
303 descriptor.getResult(0), {sliceIndexI64, zero},
304 static_cast<ConversionPatternRewriter &>(rewriter));
307 /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
308 /// tile-sized memref (`tileAlloca`).
309 void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
310 arm_sme::ArmSMETileType tileType, VectorType sliceType,
311 IntegerAttr tileId, Value sliceIndex) const {
312 // Cast the slice index to an i32.
313 auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
314 loc, rewriter.getI32Type(), sliceIndex);
315 // Create an all-true predicate for the slice.
316 auto predicateType = sliceType.clone(rewriter.getI1Type());
317 auto allTruePredicate = rewriter.create<arith::ConstantOp>(
318 loc, DenseElementsAttr::get(predicateType, true));
319 // Create padding vector (never used due to all-true predicate).
320 auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
321 // Get a pointer to the current slice.
322 auto slicePtr =
323 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
324 // Read the value of the current slice from ZA.
325 auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
326 loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
327 // Load the new tile slice back from memory into ZA.
328 createLoadTileSliceIntrinsic(
329 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
330 allTruePredicate, slicePtr, tileId, sliceIndexI32);
331 // Store the current tile slice to memory.
332 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
333 rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
334 ValueRange{sliceIndex, zero});
337 /// Emits a full in-place swap of the contents of a tile in ZA and a
338 /// tile-sized memref (`tileAlloca`).
339 void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
340 arm_sme::ArmSMETileType tileType, VectorType sliceType,
341 IntegerAttr tileId) const {
342 RewriterBase::InsertionGuard guard(rewriter);
343 // Create an scf.for over all tile slices.
344 auto minNumElts =
345 rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
346 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
347 auto upperBound = rewriter.create<arith::MulIOp>(
348 loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
349 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
350 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
351 // Emit a swap for each tile slice.
352 rewriter.setInsertionPointToStart(forOp.getBody());
353 auto sliceIndex = forOp.getInductionVar();
354 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
355 sliceIndex);
/// Selects whether `addArmSMEConversionPattern` also registers the spill/fill
/// pattern (`ConvertArmSMESpillsAndFillsToLLVM`) for a pattern's source op.
enum class RequiresSpillsAndFills {
  /// Register spills/fills (only takes effect for ops implementing
  /// `ArmSMETileOpInterface`).
  Yes,
  /// Never register spills/fills (used e.g. for `arm_sme.streaming_vl`).
  No
};
361 /// Base class for ArmSME to LLVM conversion patterns. By default, this adds
362 /// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
363 /// disabled by setting the `requiresSpillsAndFills` template parameter to
364 /// `RequiresSpillsAndFills::No`.
365 template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
366 RequiresSpillsAndFills::Yes>
367 struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
368 using ArmSMEOp = SourceOp;
369 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
371 static constexpr bool requiresSpillsAndFillsConversion() {
372 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
376 template <typename Pattern>
377 static void addArmSMEConversionPattern(RewritePatternSet &patterns,
378 LLVMTypeConverter const &typeConverter) {
379 // Register spills/fills for ops that implement the
380 // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
381 // `RequiresSpillsAndFills::Yes`.
382 if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
383 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
384 typename Pattern::ArmSMEOp>,
385 typename Pattern::ArmSMEOp>) {
386 // Add spill/fill conversions with a very high benefit to ensure
387 // they are lowered first.
388 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
389 Pattern::ArmSMEOp::getOperationName(), typeConverter,
390 /*benefit=*/1337);
392 patterns.add<Pattern>(typeConverter);
395 /// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
396 template <typename... Patterns>
397 static void
398 addArmSMEConversionPatterns(RewritePatternSet &patterns,
399 LLVMTypeConverter const &typeConverter) {
400 (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
403 /// Lower 'arm_sme.zero' to SME intrinsics.
405 /// BEFORE:
406 /// ```mlir
407 /// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
408 /// ```
410 /// AFTER:
411 /// ```mlir
412 /// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
413 /// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
414 /// ```
416 /// The 'arm_sme.get_tile' (which models the return) will fold away once all
417 /// ArmSME ops have been converted to LLVM intrinsics.
418 struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
419 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
421 LogicalResult
422 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
423 ConversionPatternRewriter &rewriter) const override {
424 auto loc = zero.getLoc();
426 auto tileId = getTileIdOrError(zero);
427 if (!tileId)
428 return failure();
430 // Get the base mask for tile based on the element size.
431 // The base mask is just the mask to zero the first tile (of a size).
432 // These masks are derived from:
433 // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
434 arm_sme::ArmSMETileType tileType =
435 *arm_sme::getSMETileType(zero.getTileType());
436 auto baseMaskForSize = [&] {
437 switch (tileType) {
438 case arm_sme::ArmSMETileType::ZAB:
439 // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
440 // 64-bit element tiles named ZA0.D to ZA7.D.
441 return 0b1111'1111;
442 case arm_sme::ArmSMETileType::ZAH:
443 // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
444 // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
445 // once for ZA1.H.
446 return 0b0101'0101;
447 case arm_sme::ArmSMETileType::ZAS:
448 // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
449 // element tiles named ZA0.D and ZA4.D.
450 // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
451 return 0b0001'0001;
452 case arm_sme::ArmSMETileType::ZAD:
453 // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
454 // setting the bit for that tile.
455 return 0b0000'0001;
456 default:
457 llvm_unreachable("bad element size");
459 }();
461 // The actual mask is just the base mask shifted by the tile ID.
462 // This will be folded to a constant after tile allocation.
464 // The shift is just derived from the layout of the tiles, and that the tile
465 // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
466 // tiles:
468 // ZA0.S = ZA0.D and ZA4.D
469 // * Tile ID -> 0
470 // * Mask -> 00010001 = (00010001 << 0)
471 // ZA1.S = ZA1.D and ZA5.D
472 // * Tile ID -> 1
473 // * Mask -> 00100010 = (00010001 << 1)
474 // ZA2.S = ZA2.D and ZA6.D
475 // * Tile ID -> 2
476 // * Mask -> 01000100 = (00010001 << 2)
477 // ZA3.S = ZA3.D and ZA7.D
478 // * Tile ID -> 3
479 // * Mask -> 10001000 = (00010001 << 3)
481 // This holds for all tile sizes.
482 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
483 rewriter.create<arm_sme::aarch64_sme_zero>(
484 loc, rewriter.getI32IntegerAttr(zeroMask));
486 // Create a placeholder op to preserve dataflow.
487 // Note: Place the `get_tile` op at the start of the block. This ensures
488 // that if there are multiple `zero` ops the intrinsics will be consecutive.
489 rewriter.setInsertionPointToStart(zero->getBlock());
490 rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
492 return success();
496 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
497 struct LoadTileSliceConversion
498 : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
499 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
501 LogicalResult
502 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
503 arm_sme::LoadTileSliceOp::Adaptor adaptor,
504 ConversionPatternRewriter &rewriter) const override {
505 auto loc = loadTileSliceOp.getLoc();
506 auto tileId = getTileIdOrError(loadTileSliceOp);
507 if (!tileId)
508 return failure();
510 Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
511 adaptor.getBase(),
512 adaptor.getIndices(), rewriter);
514 auto tileSlice = loadTileSliceOp.getTileSliceIndex();
516 // Cast tile slice to i32 for intrinsic.
517 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
518 loc, rewriter.getI32Type(), tileSlice);
520 // Create all active predicate mask.
521 auto maskOp = loadTileSliceOp.getMask();
523 auto tileVectorType = loadTileSliceOp.getVectorType();
524 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
525 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
527 // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
528 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
529 tileId, tileSliceI32);
531 // The load intrinsics have no result, replace 'arm_sme.tile_load' with
532 // the input tile to preserve dataflow.
533 rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
535 return success();
539 /// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
540 struct StoreTileSliceConversion
541 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
542 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
544 LogicalResult
545 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
546 arm_sme::StoreTileSliceOp::Adaptor adaptor,
547 ConversionPatternRewriter &rewriter) const override {
548 auto loc = storeTileSliceOp.getLoc();
549 auto tileVectorType = storeTileSliceOp.getVectorType();
551 auto tileId = getTileIdOrError(storeTileSliceOp);
552 if (!tileId)
553 return failure();
555 // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
556 Value ptr = this->getStridedElementPtr(
557 loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558 adaptor.getIndices(), rewriter);
560 auto tileSlice = storeTileSliceOp.getTileSliceIndex();
562 // Cast tile slice to i32 for intrinsic.
563 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
564 loc, rewriter.getI32Type(), tileSlice);
566 auto maskOp = storeTileSliceOp.getMask();
568 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
569 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
571 rewriter.replaceOp(storeTileSliceOp,
572 createStoreTileSliceIntrinsic(rewriter, loc, tileType,
573 layout, maskOp, ptr,
574 tileId, tileSliceI32));
576 return success();
580 /// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
581 struct InsertTileSliceConversion
582 : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
583 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
585 LogicalResult
586 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
587 arm_sme::InsertTileSliceOp::Adaptor adaptor,
588 ConversionPatternRewriter &rewriter) const override {
589 auto loc = insertTileSliceOp.getLoc();
590 auto tileType = insertTileSliceOp.getTileType();
592 auto tileId = getTileIdOrError(insertTileSliceOp);
593 if (!tileId)
594 return failure();
596 auto tileSlice = insertTileSliceOp.getTileSliceIndex();
598 // Cast tile slice from index to i32 for intrinsic.
599 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
600 loc, rewriter.getI32Type(), tileSlice);
602 // Create all active predicate mask.
603 auto one = rewriter.create<arith::ConstantOp>(
604 loc, rewriter.getI1Type(),
605 rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
606 auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
607 /*scalableDims=*/{true});
608 auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
610 // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
611 switch (insertTileSliceOp.getLayout()) {
612 case arm_sme::TileSliceLayout::Horizontal:
613 rewriter.create<arm_sme::aarch64_sme_write_horiz>(
614 loc, tileId, tileSliceI32, allActiveMask,
615 insertTileSliceOp.getVector());
616 break;
617 case arm_sme::TileSliceLayout::Vertical:
618 rewriter.create<arm_sme::aarch64_sme_write_vert>(
619 loc, tileId, tileSliceI32, allActiveMask,
620 insertTileSliceOp.getVector());
621 break;
624 // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
625 // the input tile to preserve dataflow.
626 rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
628 return success();
632 /// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
633 struct ExtractTileSliceConversion
634 : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
635 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
637 LogicalResult
638 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
639 ConversionPatternRewriter &rewriter) const override {
640 auto loc = extractTileSlice.getLoc();
641 auto sliceType = extractTileSlice.getSliceType();
642 auto sliceIndex = extractTileSlice.getTileSliceIndex();
644 auto tileId = getTileIdOrError(extractTileSlice);
645 if (!tileId)
646 return failure();
648 // Create an 'all true' predicate for the tile slice.
649 auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
650 auto allTruePredicate = rewriter.create<arith::ConstantOp>(
651 loc, DenseElementsAttr::get(predicateType, true));
653 // Zero destination/fallback for tile slice extraction.
654 auto zeroVector = rewriter.create<arith::ConstantOp>(
655 loc, sliceType, rewriter.getZeroAttr(sliceType));
657 // Cast tile slice from index to i32 for intrinsic.
658 auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
659 loc, rewriter.getI32Type(), sliceIndex);
661 // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
662 switch (extractTileSlice.getLayout()) {
663 case arm_sme::TileSliceLayout::Horizontal:
664 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
665 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
666 sliceIndexI32);
667 break;
668 case arm_sme::TileSliceLayout::Vertical:
669 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
670 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
671 sliceIndexI32);
672 break;
675 return success();
679 /// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
681 /// Example:
683 /// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
684 /// : vector<[4]xf32>, vector<[4]xf32>
686 /// is converted to:
688 /// "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
689 /// : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
690 /// vector<[4]xf32>) -> ()
692 /// Currently only supports FMOPA and BFMOPA (non-widening).
693 struct OuterProductOpConversion
694 : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
695 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
697 LogicalResult
698 matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
699 arm_sme::OuterProductOp::Adaptor adaptor,
700 ConversionPatternRewriter &rewriter) const override {
701 auto tileId = getTileIdOrError(outerProductOp);
702 if (!tileId)
703 return failure();
705 auto isSupportedType = [](VectorType vectorType) {
706 // TODO: the FP outer product instruction variants are predicated on
707 // different features [1]:
709 // * FMOPA (non-widening)
710 // * half-precision - +sme2p1,+sme-f16f16
711 // * single-precision - +sme
712 // * double-precision - +sme-f64f64
713 // * BFMOPA
714 // * half-precision - +sme2p1,+b16b16
716 // It should be possible to control lowering based on target features.
717 // [1]
718 // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
719 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
720 return false;
722 auto elementType = vectorType.getElementType();
724 if (!elementType.isF16() && !elementType.isBF16() &&
725 !elementType.isF32() && !elementType.isF64())
726 return false;
728 unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
729 vectorType.getElementTypeBitWidth();
730 return vectorType.getShape() ==
731 ArrayRef<int64_t>({minNumElts, minNumElts});
734 // TODO: Support CombiningKind::Sub for outer products.
735 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
736 return outerProductOp.emitError("unsupported kind");
738 auto resultVectorType = outerProductOp.getResultType();
739 if (!isSupportedType(resultVectorType))
740 return outerProductOp.emitError("unsupported type");
742 auto loc = outerProductOp.getLoc();
744 Value acc = outerProductOp.getAcc();
745 if (!acc) {
746 // Initalize accumulator with zero.
747 auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
748 zero.setTileId(tileId);
749 acc = zero;
752 Value lhsMask = outerProductOp.getLhsMask();
753 Value rhsMask = outerProductOp.getRhsMask();
755 if (!lhsMask || !rhsMask) {
756 auto predTy =
757 outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
758 Value allActiveMask = rewriter.create<arith::ConstantOp>(
759 loc, DenseElementsAttr::get(predTy, true));
760 lhsMask = allActiveMask;
761 rhsMask = allActiveMask;
764 // Create 'arm_sme.intr.mopa' outer product intrinsic.
765 rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
766 outerProductOp.getLhs(),
767 outerProductOp.getRhs());
769 // The outerproduct intrinsics have no result, replace
770 // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
771 rewriter.replaceOp(outerProductOp, acc);
773 return success();
777 /// Lower 2-way and 4-way widening outer products to intrinsics.
778 template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
779 struct OuterProductWideningOpConversion
780 : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
781 using ConvertArmSMEOpToLLVMPattern<
782 OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
784 LogicalResult
785 matchAndRewrite(OuterProductWideningOp op,
786 typename OuterProductWideningOp::Adaptor adaptor,
787 ConversionPatternRewriter &rewriter) const override {
788 auto tileId = getTileIdOrError(op);
789 if (!tileId)
790 return failure();
792 auto loc = op.getLoc();
793 Value acc = op.getAcc();
794 if (!acc) {
795 // Initalize accumulator with zero.
796 auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
797 zero.setTileId(tileId);
798 acc = zero;
801 Value lhsMask = op.getLhsMask();
802 Value rhsMask = op.getRhsMask();
803 if (!lhsMask || !rhsMask) {
804 auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
805 Value allActiveMask = rewriter.create<arith::ConstantOp>(
806 loc, DenseElementsAttr::get(predTy, true));
807 lhsMask = allActiveMask;
808 rhsMask = allActiveMask;
811 rewriter.create<OuterProductWideningIntrOp>(
812 loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
814 // The outerproduct intrinsics have no result, replace
815 // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
816 rewriter.replaceOp(op, acc);
818 return success();
822 /// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
824 /// Example:
826 /// %0 = arm_sme.streaming_vl <half>
828 /// is converted to:
830 /// %cnt = "arm_sme.intr.cntsh"() : () -> i64
831 /// %0 = arith.index_cast %cnt : i64 to index
833 struct StreamingVLOpConversion
834 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
835 RequiresSpillsAndFills::No> {
836 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
838 LogicalResult
839 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
840 arm_sme::StreamingVLOp::Adaptor adaptor,
841 ConversionPatternRewriter &rewriter) const override {
842 auto loc = streamingVlOp.getLoc();
843 auto i64Type = rewriter.getI64Type();
844 auto *intrOp = [&]() -> Operation * {
845 switch (streamingVlOp.getTypeSize()) {
846 case arm_sme::TypeSize::Byte:
847 return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
848 case arm_sme::TypeSize::Half:
849 return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
850 case arm_sme::TypeSize::Word:
851 return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
852 case arm_sme::TypeSize::Double:
853 return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
855 llvm_unreachable("unknown type size in StreamingVLOpConversion");
856 }();
857 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
858 streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
859 return success();
863 /// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
864 /// or-ing the zero masks. Note: In future the backend _should_ handle this.
865 static void mergeConsecutiveTileZerosInBlock(Block *block) {
866 uint32_t mergedZeroMask = 0;
867 SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
868 auto replaceMergedZeroOps = [&] {
869 auto cleanup = llvm::make_scope_exit([&] {
870 mergedZeroMask = 0;
871 zeroOpsToMerge.clear();
873 if (zeroOpsToMerge.size() <= 1)
874 return;
875 IRRewriter rewriter(zeroOpsToMerge.front());
876 rewriter.create<arm_sme::aarch64_sme_zero>(
877 zeroOpsToMerge.front().getLoc(),
878 rewriter.getI32IntegerAttr(mergedZeroMask));
879 for (auto zeroOp : zeroOpsToMerge)
880 rewriter.eraseOp(zeroOp);
882 for (Operation &op : *block) {
883 if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
884 mergedZeroMask |= zeroOp.getTileMask();
885 zeroOpsToMerge.push_back(zeroOp);
886 } else {
887 replaceMergedZeroOps();
890 replaceMergedZeroOps();
893 } // namespace
895 namespace {
897 struct ConvertArmSMEToLLVMPass
898 : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
899 ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
900 this->dumpTileLiveRanges = dumpTileLiveRanges;
902 void runOnOperation() override {
903 auto function = getOperation();
905 if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
906 return signalPassFailure();
908 LLVMConversionTarget target(getContext());
909 RewritePatternSet patterns(&getContext());
910 LLVMTypeConverter converter(&getContext());
911 configureArmSMEToLLVMConversionLegality(target);
912 populateArmSMEToLLVMConversionPatterns(converter, patterns);
914 if (failed(applyPartialConversion(function, target, std::move(patterns))))
915 signalPassFailure();
917 function->walk(mergeConsecutiveTileZerosInBlock);
919 // Walk the function and fail if there are unexpected operations on SME
920 // tile types after conversion.
921 function->walk([&](Operation *op) {
922 // These ops are legal post conversion, skip these.
923 if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
924 !op->isRegistered())
925 return;
926 auto isSMETileType = [](Type type) {
927 return arm_sme::isValidSMETileVectorType(type);
929 if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
930 llvm::any_of(op->getOperandTypes(), isSMETileType)) {
931 op->emitOpError("unexpected operation with SME tile type after "
932 "conversion to LLVM");
933 signalPassFailure();
939 } // namespace
941 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
942 target.addIllegalDialect<arm_sme::ArmSMEDialect>();
943 target.addLegalOp<
944 arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
945 arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
946 arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
947 arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
948 arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
949 arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
950 arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
951 arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
952 arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
953 arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
954 arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
955 arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
956 arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
957 arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
958 arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
959 arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
960 arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
961 arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
962 arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
963 arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
964 arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
965 arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
966 arm_sme::aarch64_sme_cntsd>();
967 target.addLegalDialect<arith::ArithDialect,
968 /* The following are used to lower tile spills/fills */
969 vector::VectorDialect, scf::SCFDialect,
970 memref::MemRefDialect>();
971 // Pseudo operations. These cannot be code-generated but may exist in the
972 // input IR, or be generated during the conversion. They need to be eliminated
973 // before the final conversion to LLVM IR (and likely will be due to DCE).
974 target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
975 UnrealizedConversionCastOp>();
978 void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
979 RewritePatternSet &patterns) {
980 converter.addConversion([&](VectorType type) -> std::optional<Type> {
981 // There's no LLVM type for SME tiles, but after lowering to intrinsics all
982 // SME vector types should be eliminated.
983 if (arm_sme::isValidSMETileVectorType(type))
984 return type;
985 return std::nullopt;
988 addArmSMEConversionPatterns<
989 LoadTileSliceConversion, ExtractTileSliceConversion,
990 InsertTileSliceConversion, StoreTileSliceConversion,
991 StreamingVLOpConversion, OuterProductOpConversion,
992 OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
993 arm_sme::aarch64_sme_mopa_wide>,
994 OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
995 arm_sme::aarch64_sme_mops_wide>,
996 OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
997 arm_sme::aarch64_sme_smopa_za32>,
998 OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
999 arm_sme::aarch64_sme_smops_za32>,
1000 OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
1001 arm_sme::aarch64_sme_umopa_za32>,
1002 OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
1003 arm_sme::aarch64_sme_umops_za32>,
1004 OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1005 arm_sme::aarch64_sme_smopa_wide>,
1006 OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1007 arm_sme::aarch64_sme_smops_wide>,
1008 OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1009 arm_sme::aarch64_sme_umopa_wide>,
1010 OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1011 arm_sme::aarch64_sme_umops_wide>,
1012 OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1013 arm_sme::aarch64_sme_sumopa_wide>,
1014 OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1015 arm_sme::aarch64_sme_sumops_wide>,
1016 OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1017 arm_sme::aarch64_sme_usmopa_wide>,
1018 OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1019 arm_sme::aarch64_sme_usmops_wide>,
1020 ZeroOpConversion>(patterns, converter);
1023 std::unique_ptr<Pass>
1024 mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
1025 return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);