//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of ArmSME operations to LLVM intrinsics.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ScopeExit.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
static constexpr StringLiteral
    kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
/// Helper to create an 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
static Operation *createLoadTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
}
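
// For reference, the generated intrinsic takes the slice predicate, the base
// pointer, the tile ID (as an attribute), and the slice index. For a 32-bit
// horizontal slice this corresponds to IR roughly like the following
// (illustrative only; the SSA names are invented and the vector types depend
// on the tile element type):
//
//   "arm_sme.intr.ld1w.horiz"(%mask, %ptr, %slice_idx) <{tile_id = 0 : i32}>
//     : (vector<[4]xi1>, !llvm.ptr, i32) -> ()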
/// Helper to create an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic.
static Operation *createStoreTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
}
static IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
  auto tileId = op.getTileId();
  if (!tileId)
    op.emitOpError(
        "expected tile ID to be allocated before conversion to LLVM");
  return tileId;
}
/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
/// placed in the first block of the function.
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
                    FunctionOpInterface func,
                    arm_sme::ArmSMETileOpInterface tileOp) {
  RewriterBase::InsertionGuard g(rewriter);
  // Move to the first operation in the function.
  rewriter.setInsertionPointToStart(&func.getBlocks().front());
  // Create an alloca matching the tile size of the `tileOp`.
  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
  auto tileElementType = tileOp.getTileType().getElementType();
  auto memrefType = MemRefType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
  auto minElementsOp =
      rewriter.create<arith::ConstantIndexOp>(loc, minElements);
  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
  auto alloca = rewriter.create<memref::AllocaOp>(
      loc, memrefType, ValueRange{vectorLen, vectorLen});
  return alloca;
}
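
// For a 32-bit element tile the IR created above looks roughly like this
// (illustrative only; the constant 4 comes from getSMETileSliceMinNumElts
// for a 32-bit element type, i.e. 128 bits / 32 bits):
//
//   %vscale   = vector.vscale
//   %min_elts = arith.constant 4 : index
//   %svl      = arith.muli %vscale, %min_elts : index
//   %tile_mem = memref.alloca(%svl, %svl) : memref<?x?xf32>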
/// Finds or creates an alloca for a spill of a tile.
static memref::AllocaOp getOrCreateAllocaForTile(
    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
  // Find an alloca at the top of the function tagged with a
  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
  for (auto &op : func.getBlocks().front()) {
    auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
    if (!alloca)
      continue;
    auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
        alloca->getDiscardableAttr(kInMemoryTileIdAttr));
    if (!inMemoryTileId)
      continue;
    if (inMemoryTileId.getInt() == tileId)
      return alloca;
  }
  // Otherwise, create a new alloca:
  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
                             rewriter.getI32IntegerAttr(tileId));
  return alloca;
}
/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
/// the op to tile 0, then emitting a full tile swap between ZA and memory
/// before + after the tile op.
///
/// Example:
///
///    // Note: <IN MEMORY TILE> = tile ID >= 16.
///    arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
///
///    // At function entry:
///    %spill = memref.alloca ... : memref<?x?xty>
///
///    // Around the op:
///    scf.for %slice_idx {
///      %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///      "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///      vector.store %slice_to_save, %spill[%slice_idx, %c0]
///    }
///    arm_sme.tile_op { tile_id = 0 }
///    scf.for %slice_idx {
///      %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///      "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///      vector.store %slice_to_save, %spill[%slice_idx, %c0]
///    }
///
/// Note that these spills/fills are not inserted earlier, as the concept of a
/// register, and the need to swap its contents, can't really be represented
/// correctly at a high level in MLIR.
///
/// TODO: Reduce the spills/reloads to single slices where possible (and omit
/// redundant reloads). This could be done via a method on the
/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
///
/// `tileOp.getZaUsage()` could return:
///
/// struct ArmSMEOpZAUsage {
///   enum class Kind {
///     TileRead,        // Omit store after tile operation.
///     TileWrite,       // Omit load before tile operation.
///     TileReadWrite,   // Needs both tile load and store.
///     SliceRead,       // Spill single slice and omit store after operation.
///     SliceWrite,      // Spill single slice and omit load before operation.
///     SliceReadWrite   // Spill single slice.
///   };
///   Value sliceIndex {};
///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
/// };
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {

  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
                                    const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
                             typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
    // Tile has a real (hardware) tile. No spills/reloads required.
    if (!tileOp.isInMemoryTile())
      return failure();

    tileOp->emitWarning(
        "failed to allocate SME virtual tile to operation, tile value will go "
        "through memory, expect degraded performance");

    // Step 1. Create an alloca for the tile at the top of the function (if one
    // does not already exist).
    auto loc = tileOp.getLoc();
    auto func = tileOp->getParentOfType<FunctionOpInterface>();
    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
                                               tileOp.getTileId().getInt());

    // Step 2. Assign the op a real tile ID.
    // For simplicity, we always use tile 0 (which always exists).
    auto zeroTileId = rewriter.getI32IntegerAttr(0);
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

    VectorType tileVectorType = tileOp.getTileType();
    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
    auto swapInMemoryTileWithSMETileZero = [&] {
      emitFullTileSwap(rewriter, loc, tileAlloca,
                       *arm_sme::getSMETileType(tileVectorType), sliceType,
                       zeroTileId);
    };

    // Step 3. Emit tile swaps before and after the op.
    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
    // touches (i.e. a single tile slice).
    rewriter.setInsertionPoint(op);
    // Swap the contents of ZA and the in-memory tile before the op.
    swapInMemoryTileWithSMETileZero();
    rewriter.setInsertionPointAfter(op);
    // Swap the tile back out to memory again after the op.
    swapInMemoryTileWithSMETileZero();

    return success();
  }
  /// Extracts a pointer to a slice of an in-memory tile.
  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
                                Value tileMemory, Value sliceIndex) const {
    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
    auto descriptor =
        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), sliceIndex);
    return getStridedElementPtr(
        loc, llvm::cast<MemRefType>(tileMemory.getType()),
        descriptor.getResult(0), {sliceIndexI64, zero},
        static_cast<ConversionPatternRewriter &>(rewriter));
  }
  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
  /// tile-sized memref (`tileAlloca`).
  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
                     IntegerAttr tileId, Value sliceIndex) const {
    // Cast the slice index to an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);
    // Create an all-true predicate for the slice.
    auto predicateType = sliceType.clone(rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));
    // Create padding vector (never used due to all-true predicate).
    auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
    // Get a pointer to the current slice.
    auto slicePtr =
        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
    // Read the value of the current slice from ZA.
    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
    // Load the new tile slice back from memory into ZA.
    createLoadTileSliceIntrinsic(
        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
        allTruePredicate, slicePtr, tileId, sliceIndexI32);
    // Store the current tile slice to memory.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
                                     ValueRange{sliceIndex, zero});
  }
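
  // Taken together, the ops emitted above amount to the following per-slice
  // sequence (an illustrative sketch for a 32-bit element tile; the SSA names
  // are invented for readability):
  //
  //   %current = "arm_sme.intr.read.horiz"(%pad, %ptrue, %slice_i32)
  //                <{tile_id = 0 : i32}> : (...) -> vector<[4]xf32>
  //   "arm_sme.intr.ld1w.horiz"(%ptrue, %slice_ptr, %slice_i32)
  //                <{tile_id = 0 : i32}> : (...) -> ()
  //   vector.store %current, %tile_alloca[%slice_idx, %c0]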
  /// Emits a full in-place swap of the contents of a tile in ZA and a
  /// tile-sized memref (`tileAlloca`).
  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
                        IntegerAttr tileId) const {
    RewriterBase::InsertionGuard guard(rewriter);
    // Create an scf.for over all tile slices.
    auto minNumElts =
        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(
        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    // Emit a swap for each tile slice.
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto sliceIndex = forOp.getInductionVar();
    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
                  sliceIndex);
  }
};
enum class RequiresSpillsAndFills { Yes, No };
/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};

template <typename Pattern>
static void
addArmSMEConversionPattern(RewritePatternSet &patterns,
                           LLVMTypeConverter const &typeConverter) {
  // Register spills/fills for ops that implement the
  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
  // `RequiresSpillsAndFills::Yes`.
  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
                std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
                                      typename Pattern::ArmSMEOp>,
                                  typename Pattern::ArmSMEOp>) {
    // Add spill/fill conversions with a very high benefit to ensure
    // they are lowered first.
    patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
        Pattern::ArmSMEOp::getOperationName(), typeConverter,
        /*benefit=*/1337);
  }
  patterns.add<Pattern>(typeConverter);
}

/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
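
// For example, `populateArmSMEToLLVMConversionPatterns` (defined at the bottom
// of this file) registers every conversion pattern with one variadic call:
//
//   addArmSMEConversionPatterns<
//       LoadTileSliceConversion, /* ... */, ZeroOpConversion>(patterns,
//                                                             converter);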
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// Example:
///
///   %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
///
/// is converted to:
///
///   "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///   %v = arm_sme.get_tile : vector<[4]x[4]xi32>
///
/// The 'arm_sme.get_tile' (which models the return) will fold away once all
/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();

    auto tileId = getTileIdOrError(zero);
    if (!tileId)
      return failure();

    // Get the base mask for tile based on the element size.
    // The base mask is just the mask to zero the first tile (of a size).
    // These masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
    arm_sme::ArmSMETileType tileType =
        *arm_sme::getSMETileType(zero.getTileType());
    auto baseMaskForSize = [&] {
      switch (tileType) {
      case arm_sme::ArmSMETileType::ZAB:
        // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
        // 64-bit element tiles named ZA0.D to ZA7.D.
        return 0b1111'1111;
      case arm_sme::ArmSMETileType::ZAH:
        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
        // by one for ZA1.H.
        return 0b0101'0101;
      case arm_sme::ArmSMETileType::ZAS:
        // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
        // element tiles named ZA0.D and ZA4.D.
        // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
        return 0b0001'0001;
      case arm_sme::ArmSMETileType::ZAD:
        // Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
        // setting the bit for that tile.
        return 0b0000'0001;
      default:
        llvm_unreachable("bad element size");
      }
    }();

    // The actual mask is just the base mask shifted by the tile ID.
    // This will be folded to a constant after tile allocation.
    //
    // The shift is just derived from the layout of the tiles, and that the
    // tile ID is the index of the tile. For example, looking at the 32-bit
    // ZAx.S tiles:
    //
    // ZA0.S = ZA0.D and ZA4.D
    //  * Tile ID -> 0
    //  * Mask    -> 00010001 = (00010001 << 0)
    // ZA1.S = ZA1.D and ZA5.D
    //  * Tile ID -> 1
    //  * Mask    -> 00100010 = (00010001 << 1)
    // ZA2.S = ZA2.D and ZA6.D
    //  * Tile ID -> 2
    //  * Mask    -> 01000100 = (00010001 << 2)
    // ZA3.S = ZA3.D and ZA7.D
    //  * Tile ID -> 3
    //  * Mask    -> 10001000 = (00010001 << 3)
    //
    // This holds for all tile sizes.
    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        loc, rewriter.getI32IntegerAttr(zeroMask));

    // Create a placeholder op to preserve dataflow.
    // Note: Place the `get_tile` op at the start of the block. This ensures
    // that if there are multiple `zero` ops the intrinsics will be consecutive.
    rewriter.setInsertionPointToStart(zero->getBlock());
    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());

    return success();
  }
};
/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
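///
/// A rough sketch of the lowering (illustrative; the exact assembly format of
/// the source op and the SSA names are not verbatim):
///
///   %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %idx
///     {tile_id = 0 : i32}
///     : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
///
/// is converted to:
///
///   "arm_sme.intr.ld1w.horiz"(%mask, %ptr, %idx_i32) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, !llvm.ptr, i32) -> ()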
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = loadTileSliceOp.getLoc();
    auto tileId = getTileIdOrError(loadTileSliceOp);
    if (!tileId)
      return failure();

    Value ptr = this->getStridedElementPtr(loc,
                                           loadTileSliceOp.getMemRefType(),
                                           adaptor.getBase(),
                                           adaptor.getIndices(), rewriter);

    auto tileSlice = loadTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // The op's mask is used as the predicate for the intrinsic.
    auto maskOp = loadTileSliceOp.getMask();

    auto tileVectorType = loadTileSliceOp.getVectorType();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();

    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
    createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
                                 tileId, tileSliceI32);

    // The load intrinsics have no result, replace 'arm_sme.load_tile_slice'
    // with the input tile to preserve dataflow.
    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());

    return success();
  }
};
/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
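///
/// This is the store analog of the load lowering above (illustrative sketch,
/// not verbatim IR):
///
///   arm_sme.store_tile_slice %tile, %idx, %mask, %base[%c0]
///     {tile_id = 0 : i32}
///     : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
///
/// is converted to:
///
///   "arm_sme.intr.st1w.horiz"(%mask, %ptr, %idx_i32) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, !llvm.ptr, i32) -> ()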
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = storeTileSliceOp.getLoc();
    auto tileVectorType = storeTileSliceOp.getVectorType();

    auto tileId = getTileIdOrError(storeTileSliceOp);
    if (!tileId)
      return failure();

    // Create 'arm_sme.intr.st1*.(horiz|vert)' intrinsic to store ZA tile
    // slice.
    Value ptr = this->getStridedElementPtr(
        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices(), rewriter);

    auto tileSlice = storeTileSliceOp.getTileSliceIndex();

    // Cast tile slice to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    auto maskOp = storeTileSliceOp.getMask();

    arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);

    rewriter.replaceOp(storeTileSliceOp,
                       createStoreTileSliceIntrinsic(rewriter, loc, tileType,
                                                     layout, maskOp, ptr,
                                                     tileId, tileSliceI32));

    return success();
  }
};
/// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
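///
/// A rough sketch for a horizontal slice (illustrative; SSA names invented):
///
///   %tile_update = arm_sme.insert_tile_slice %vector, %tile[%idx]
///     {tile_id = 0 : i32} : vector<[4]xi32> into vector<[4]x[4]xi32>
///
/// is converted to:
///
///   "arm_sme.intr.write.horiz"(%idx_i32, %ptrue, %vector)
///     <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()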
struct InsertTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
                  arm_sme::InsertTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertTileSliceOp.getLoc();
    auto tileType = insertTileSliceOp.getTileType();

    auto tileId = getTileIdOrError(insertTileSliceOp);
    if (!tileId)
      return failure();

    auto tileSlice = insertTileSliceOp.getTileSliceIndex();

    // Cast tile slice from index to i32 for intrinsic.
    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), tileSlice);

    // Create all active predicate mask.
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                  /*scalableDims=*/{true});
    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
    switch (insertTileSliceOp.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.create<arm_sme::aarch64_sme_write_vert>(
          loc, tileId, tileSliceI32, allActiveMask,
          insertTileSliceOp.getVector());
      break;
    }

    // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
    // the input tile to preserve dataflow.
    rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());

    return success();
  }
};
/// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
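///
/// A rough sketch for a horizontal slice (illustrative; SSA names invented):
///
///   %slice = arm_sme.extract_tile_slice %tile[%idx]
///     {tile_id = 0 : i32} : vector<[4]xi32> from vector<[4]x[4]xi32>
///
/// is converted to:
///
///   %slice = "arm_sme.intr.read.horiz"(%zeroes, %ptrue, %idx_i32)
///     <{tile_id = 0 : i32}>
///     : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>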
struct ExtractTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractTileSlice.getLoc();
    auto sliceType = extractTileSlice.getSliceType();
    auto sliceIndex = extractTileSlice.getTileSliceIndex();

    auto tileId = getTileIdOrError(extractTileSlice);
    if (!tileId)
      return failure();

    // Create an 'all true' predicate for the tile slice.
    auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));

    // Zero destination/fallback for tile slice extraction.
    auto zeroVector = rewriter.create<arith::ConstantOp>(
        loc, sliceType, rewriter.getZeroAttr(sliceType));

    // Cast tile slice from index to i32 for intrinsic.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);

    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
    switch (extractTileSlice.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    case arm_sme::TileSliceLayout::Vertical:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
          extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
          sliceIndexI32);
      break;
    }

    return success();
  }
};
/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA (non-widening)
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      //
      // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;

      auto elementType = vectorType.getElementType();

      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;

      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    // TODO: Support CombiningKind::Sub for outer products.
    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();
    if (!lhsMask || !rhsMask) {
      auto predTy =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                               outerProductOp.getLhs(),
                                               outerProductOp.getRhs());

    // The outer product intrinsics have no result, replace
    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
    rewriter.replaceOp(outerProductOp, acc);

    return success();
  }
};
/// Lower 2-way and 4-way widening outer products to intrinsics.
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    if (!lhsMask || !rhsMask) {
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    rewriter.create<OuterProductWideningIntrOp>(
        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());

    // The outer product intrinsics have no result, replace the widening
    // outer product op with the input tile to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};
/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
///
///   %0 = arm_sme.streaming_vl <half>
///
/// is converted to:
///
///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
///   %0 = arith.index_cast %cnt : i64 to index
///
struct StreamingVLOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
                                          RequiresSpillsAndFills::No> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
                  arm_sme::StreamingVLOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = streamingVlOp.getLoc();
    auto i64Type = rewriter.getI64Type();
    auto *intrOp = [&]() -> Operation * {
      switch (streamingVlOp.getTypeSize()) {
      case arm_sme::TypeSize::Byte:
        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
      case arm_sme::TypeSize::Half:
        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
      case arm_sme::TypeSize::Word:
        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
      case arm_sme::TypeSize::Double:
        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
      }
      llvm_unreachable("unknown type size in StreamingVLOpConversion");
    }();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
    return success();
  }
};
/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
/// or-ing the zero masks. Note: In future the backend _should_ handle this.
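///
/// For example (illustrative masks; 17 | 34 == 51):
///
///   "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
///   "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
///
/// folds to:
///
///   "arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()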
static void mergeConsecutiveTileZerosInBlock(Block *block) {
  uint32_t mergedZeroMask = 0;
  SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
  auto replaceMergedZeroOps = [&] {
    auto cleanup = llvm::make_scope_exit([&] {
      mergedZeroMask = 0;
      zeroOpsToMerge.clear();
    });
    if (zeroOpsToMerge.size() <= 1)
      return;
    IRRewriter rewriter(zeroOpsToMerge.front());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        zeroOpsToMerge.front().getLoc(),
        rewriter.getI32IntegerAttr(mergedZeroMask));
    for (auto zeroOp : zeroOpsToMerge)
      rewriter.eraseOp(zeroOp);
  };
  for (Operation &op : *block) {
    if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
      mergedZeroMask |= zeroOp.getTileMask();
      zeroOpsToMerge.push_back(zeroOp);
    } else {
      replaceMergedZeroOps();
    }
  }
  replaceMergedZeroOps();
}
struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
    this->dumpTileLiveRanges = dumpTileLiveRanges;
  }
  void runOnOperation() override {
    auto function = getOperation();

    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      return signalPassFailure();

    LLVMConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    configureArmSMEToLLVMConversionLegality(target);
    populateArmSMEToLLVMConversionPatterns(converter, patterns);

    if (failed(applyPartialConversion(function, target, std::move(patterns))))
      signalPassFailure();

    function->walk(mergeConsecutiveTileZerosInBlock);

    // Walk the function and fail if there are unexpected operations on SME
    // tile types after conversion.
    function->walk([&](Operation *op) {
      // These ops are legal post conversion, skip these (unrealized casts are
      // also legal; see the legality configuration below).
      if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
          isa<UnrealizedConversionCastOp>(op))
        return;
      auto isSMETileType = [](Type type) {
        return arm_sme::isValidSMETileVectorType(type);
      };
      if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
          llvm::any_of(op->getOperandTypes(), isSMETileType)) {
        op->emitOpError("unexpected operation with SME tile type after "
                        "conversion to LLVM");
        signalPassFailure();
      }
    });
  }
};

} // namespace
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
  target.addLegalOp<
      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
      arm_sme::aarch64_sme_cntsd>();
  target.addLegalDialect<arith::ArithDialect,
                         /* The following are used to lower tile spills/fills */
                         vector::VectorDialect, scf::SCFDialect,
                         memref::MemRefDialect>();
  // Pseudo operations. These cannot be code-generated but may exist in the
  // input IR, or be generated during the conversion. They need to be
  // eliminated before the final conversion to LLVM IR (and likely will be due
  // to DCE).
  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
                    UnrealizedConversionCastOp>();
}
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  converter.addConversion([&](VectorType type) -> std::optional<Type> {
    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
    // SME vector types should be eliminated.
    if (arm_sme::isValidSMETileVectorType(type))
      return type;
    return std::nullopt;
  });

  addArmSMEConversionPatterns<
      LoadTileSliceConversion, ExtractTileSliceConversion,
      InsertTileSliceConversion, StoreTileSliceConversion,
      StreamingVLOpConversion, OuterProductOpConversion,
      OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
                                       arm_sme::aarch64_sme_mopa_wide>,
      OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
                                       arm_sme::aarch64_sme_mops_wide>,
      OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
                                       arm_sme::aarch64_sme_smopa_za32>,
      OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
                                       arm_sme::aarch64_sme_smops_za32>,
      OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
                                       arm_sme::aarch64_sme_umopa_za32>,
      OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
                                       arm_sme::aarch64_sme_umops_za32>,
      OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
                                       arm_sme::aarch64_sme_smopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
                                       arm_sme::aarch64_sme_smops_wide>,
      OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
                                       arm_sme::aarch64_sme_umopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
                                       arm_sme::aarch64_sme_umops_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
                                       arm_sme::aarch64_sme_sumopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
                                       arm_sme::aarch64_sme_sumops_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
                                       arm_sme::aarch64_sme_usmopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                       arm_sme::aarch64_sme_usmops_wide>,
      ZeroOpConversion>(patterns, converter);
}
std::unique_ptr<Pass>
mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
}