//===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//    transposed/embedded in a larger vector, etc.
//  * Improve cost-modeling, e.g. choose different number of rows/columns
//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"
static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));
enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));
/// Helper function to either return \p Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0      1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  // Cast elementwise vector start pointer to a pointer to a vector
  // (EltType x NumElements)*.
  auto *VecType = FixedVectorType::get(EltType, NumElements);
  Type *VecPtrType = PointerType::get(VecType, AS);
  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major layout).
///    The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column-major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
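///
/// As a rough illustration (the value names here are invented for this
/// comment, not taken from a test), a flat <4 x double> operand with shape
/// 2 x 2 in column-major layout is handled as two column vectors:
///
///   %col.0 = shufflevector <4 x double> %flat, <4 x double> undef,
///                          <2 x i32> <i32 0, i32 1>
///   %col.1 = shufflevector <4 x double> %flat, <4 x double> undef,
///                          <2 x i32> <i32 2, i32 3>
///
/// Only users without shape information see the re-concatenated flat result.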
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications. This is the number of
    /// transposes that we failed to make "free" via such optimizations.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy()
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(UndefValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      assert(Vectors.size() > 0 && "Cannot call getNumColumns without columns");
      return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
    }

    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
      return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };

  /// Holds the shape (number of rows and columns) of a matrix value.
  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    bool IsColumnMajor;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }
  };
  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_column_major_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well. A ValueMap is used so that when
  /// sub-passes like optimizeTransposes perform RAUW, the map stays
  /// up-to-date.
  ValueMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we are not replacing all
  /// users of a lowered instruction if shape information is available; those
  /// users need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
  /// Returns the fast math flags to use while lowering \p Inst, honoring the
  /// -matrix-allow-contract option.
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());

    return FMF;
  }
public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const { return !DT; }

  /// Return the estimated number of vector ops required for an operation on
  /// \p ST * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(
                                TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedSize()));
  }
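
  // Example for the estimate above (hypothetical target): with 128-bit vector
  // registers, a <4 x double> value (256 bits) is estimated as
  // ceil(256 / 128) = 2 vector ops.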
  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;
      // Otherwise embed it in a flat vector and split it below.
      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  /// Returns true for (elementwise) instructions whose result shape matches
  /// the shape of their matrix operands.
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }
  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                 m_Value(), m_Value(), m_Value(), m_Value(M),
                                 m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to the
    // worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
      // We need to remove Old from the ShapeMap otherwise RAUW will replace it
      // with New. We should only add New if it supports shape info, so we
      // insert it conditionally instead.
      auto S = ShapeMap.find(&Old);
      if (S != ShapeMap.end()) {
        ShapeInfo Shape = S->second;
        ShapeMap.erase(S);
        if (supportsShapeInfo(New))
          ShapeMap.insert({New, Shape});
      }
      Old.replaceAllUsesWith(New);
    };

    // First sink all transposes inside matmuls, hoping that we end up with NN,
    // NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II. By default continue on the next/prev instruction.
        ++II;
        // If we were to erase II, move again.
        auto EraseFromParent = [&II](Value *V) {
          auto *Inst = cast<Instruction>(V);
          if (Inst->use_empty()) {
            if (Inst == &*II)
              ++II;
            Inst->eraseFromParent();
          }
        };

        // If we're creating a new instruction, continue from there.
        Instruction *NewInst = nullptr;

        IRBuilder<> IB(&I);
        MatrixBuilder<IRBuilder<>> Builder(IB);

        Value *TA, *TAMA, *TAMB;
        ConstantInt *R, *K, *C;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {

          // Transpose of a transpose is a nop.
          Value *TATA;
          if (match(TA,
                    m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
            ReplaceAllUsesWith(I, TATA);
            EraseFromParent(&I);
            EraseFromParent(TA);
          }

          // (A * B)^t -> B^t * A^t
          else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                                 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                                 m_ConstantInt(K), m_ConstantInt(C)))) {
            // A is R x K and B is K x C, so B^t is C x K and A^t is K x R.
            Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
                                                      C->getZExtValue(),
                                                      TAMB->getName() + "_t");
            // We are being run after shape prop, add shape for newly created
            // instructions so that we lower them later.
            setShapeInfo(T0, {C, K});
            Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
                                                      K->getZExtValue(),
                                                      TAMA->getName() + "_t");
            setShapeInfo(T1, {K, R});
            NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
                                                   K->getZExtValue(),
                                                   R->getZExtValue(), "mmul");
            ReplaceAllUsesWith(I, NewInst);
            EraseFromParent(&I);
            EraseFromParent(TA);
          }
        }

        // If we replaced I with a new instruction, continue from there.
        if (NewInst)
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul, lift the transpose. We may be able to fold into
    // consuming multiply.
    for (BasicBlock &BB : Func) {
      for (BasicBlock::iterator II = BB.begin(); II != BB.end();) {
        Instruction *I = &*II;
        // We may remove I; continue from the next instruction by default.
        ++II;
        Value *A, *B, *AT, *BT;
        ConstantInt *R, *K, *C;
        // A^t * B^t -> (B * A)^t
        if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
                           m_Value(A), m_Value(B), m_ConstantInt(R),
                           m_ConstantInt(K), m_ConstantInt(C))) &&
            match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
            match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
          IRBuilder<> IB(&*I);
          MatrixBuilder<IRBuilder<>> Builder(IB);
          Value *M = Builder.CreateMatrixMultiply(
              BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
          setShapeInfo(M, {C, R});
          Instruction *NewInst = Builder.CreateMatrixTranspose(
              M, C->getZExtValue(), R->getZExtValue());
          ReplaceAllUsesWith(*I, NewInst);
          if (I->use_empty())
            I->eraseFromParent();
          if (A->use_empty())
            cast<Instruction>(A)->eraseFromParent();
          if (A != B && B->use_empty())
            cast<Instruction>(B)->eraseFromParent();
        }
      }
    }
  }
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      LLVM_DEBUG({
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.dump();
      });
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates for
    // fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to fuse candidates.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);
    Changed = !FusedInsts.empty();

    // Third, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains.
    //
    // Because we add to ToRemove during fusion we can't guarantee that defs
    // are before uses. Change uses to undef temporarily as these should get
    // removed as well.
    //
    // For verification, we keep track of where we changed uses to undefs in
    // UndefedInsts and then check that we in fact remove them.
    SmallSet<Instruction *, 16> UndefedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
        Use &U = *I++;
        if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
          UndefedInsts.insert(Undefed);
        U.set(UndefValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      UndefedInsts.erase(Inst);
    }
    if (!UndefedInsts.empty()) {
      // If we didn't remove all undefed instructions, it's a hard error.
      dbgs() << "Undefed but present instructions:\n";
      for (auto *I : UndefedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Undefed but instruction not removed");
    }

    return Changed;
  }
  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
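
  // Worked example for getAlignForIndex above (illustrative): for double
  // elements (8 bytes), a constant stride of 5 and an initial alignment of 16,
  // column Idx = 2 starts at byte offset 2 * 5 * 8 = 80, giving
  // commonAlignment(16, 80) = 16, while Idx = 1 (offset 40) drops it to 8.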
  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = createElementPtr(Ptr, EltTy, Builder);
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }
  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
  /// starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {

    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    return loadMatrix(TileTy, TilePtr, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }
  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
                 bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
  /// MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
    Value *EltPtr =
        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());
    Type *TilePtrTy = PointerType::get(TileTy, AS);
    Value *TilePtr =
        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");

    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  // Set elements I..I+NumElts-1 of Col to Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
  /// Compute \p Sum + \p A * \p B, or just \p A * \p B if \p Sum is null.
  /// Uses an fmuladd intrinsic for floating-point types when contraction is
  /// allowed.
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Compute \p Result += \p A * \p B for input matrices with left-associating
  /// addition.
  ///
  /// We can fold a transpose into the operand that is used to extract scalars.
  /// This is the first operand with row-major and the second with
  /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
  /// operand is transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedSize() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axis and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0 iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axis and accumulate the rows. With this
      // the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
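
  // Conceptually, the column-major path above computes, in blocks of at most
  // VF rows (illustrative pseudo code, not part of the emitted IR):
  //
  //   for (J : result columns)
  //     for (I : result rows, step BlockSize)
  //       for (K : inner dimension)
  //         Result[I..I+BlockSize)[J] += A[I..I+BlockSize)[K] * splat(B[K][J])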
  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. The new location, or otherwise the original
  /// one, is returned.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the 2nd part of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
    // DT. Manually collect dominator tree updates, to avoid unnecessary work,
    // as we adjust Check0 and Check1's branches.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT->Delete, Check0, Succ});

    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they are
    // guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    AllocaInst *NewLd =
        Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
    Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
                         Load->getPointerOperand(), Load->getAlign(),
                         LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(NewLd, Copy);

    // Adjust the dominator tree for the newly created blocks.
    DTUpdates.push_back({DT->Insert, Check0, Check1});
    DTUpdates.push_back({DT->Insert, Check0, Fusion});
    DTUpdates.push_back({DT->Insert, Check1, Copy});
    DTUpdates.push_back({DT->Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }
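
  // The two emitted branches implement the usual interval-overlap test:
  // [LoadBegin, LoadEnd) and [StoreBegin, StoreEnd) overlap iff
  // LoadBegin < StoreEnd and StoreBegin < LoadEnd; only in that case is the
  // copy to the new alloca taken.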
  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedSize() /
            EltType->getPrimitiveSizeInBits().getFixedSize(),
        1U);

    // Cost model for tiling
    //
    // For tiling to be beneficial, we need reuse either along the R or
    // the C axis. We vectorize along the R axis, so that means at least
    // one full vector of rows, or more than one column.
    // TODO: Also consider cost of copying if operands alias.
    if (R <= VF && C == 1)
      return false;
    // Then we need enough elements to exceed the number of vector
    // registers we have. Note that this is an oversimplification since
    // fusing also takes some extra loads which may exceed the number of
    // reloads necessary.
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true);
  }
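
  // Illustrative example (hypothetical target): with 256-bit registers and
  // double elements, VF = 4. For a 16x16 * 16x16 multiply, Op0Regs and Op1Regs
  // are each (16 + 3) / 4 * 16 = 64, which exceeds any realistic vector
  // register file, so tiling is considered profitable.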
  /// Returns a zero-initialized \p R x \p C matrix with \p EltType elements.
  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }
  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoopHeader->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A = loadMatrix(LPtr, {}, false, LShape, TI.CurrentRow, TI.CurrentK,
                            {TileSize, TileSize}, EltType, Builder);
    MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
                            {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.CurrentRow, TI.CurrentCol, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.InnerLoopLatch);

    // Force unrolling of a few iterations of the inner loop, to make sure there
    // is enough work per iteration.
    // FIXME: The unroller should make this decision directly instead, but
    // currently the cost-model is not up to the task.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.InnerLoopHeader),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }
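
  // The generated nest conceptually looks like (illustrative pseudo code):
  //
  //   for (Col = 0; Col < NumColumns(B); Col += TileSize)
  //     for (Row = 0; Row < NumRows(A); Row += TileSize)
  //       for (K = 0; K < NumColumns(A); K += TileSize)
  //         Res[Row.., Col..] += A[Row.., K..] * B[K.., Col..]
  //
  // with the accumulator kept in ColumnPhis across the inner K loop and
  // stored back once per (Row, Col) tile.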
  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();

    if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
    else {
      IRBuilder<> Builder(Store);
      for (unsigned J = 0; J < C; J += TileSize)
        for (unsigned I = 0; I < R; I += TileSize) {
          const unsigned TileR = std::min(R - I, unsigned(TileSize));
          const unsigned TileC = std::min(C - J, unsigned(TileSize));
          MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);

          for (unsigned K = 0; K < M; K += TileSize) {
            const unsigned TileM = std::min(M - K, unsigned(TileSize));
            MatrixTy A =
                loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
                           LShape, Builder.getInt64(I), Builder.getInt64(K),
                           {TileR, TileM}, EltType, Builder);
            MatrixTy B =
                loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                           RShape, Builder.getInt64(K), Builder.getInt64(J),
                           {TileM, TileC}, EltType, Builder);
            emitMatrixMultiply(Res, A, B, Builder, true, false,
                               getFastMathFlags(MatMul));
          }
          storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                      Builder.getInt64(I), Builder.getInt64(J), EltType,
                      Builder);
        }
    }

    // Mark eliminated instructions as fused and remove them.
    FusedInsts.insert(Store);
    FusedInsts.insert(MatMul);
    Store->eraseFromParent();
    MatMul->eraseFromParent();
    if (LoadOp0->hasNUses(0)) {
      FusedInsts.insert(LoadOp0);
      LoadOp0->eraseFromParent();
    }
    if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
      FusedInsts.insert(LoadOp1);
      LoadOp1->eraseFromParent();
    }
  }
1529 /// Try to lower matrix multiply chains by fusing operations.
1531 /// Call finalizeLowering on lowered instructions. Instructions that are
1532 /// completely eliminated by fusion are added to \p FusedInsts.
1533 void LowerMatrixMultiplyFused(CallInst
*MatMul
,
1534 SmallPtrSetImpl
<Instruction
*> &FusedInsts
) {
1535 if (!FuseMatrix
|| !DT
)
1538 assert(AA
&& LI
&& "Analyses should be available");
1540 Value
*A
= MatMul
->getArgOperand(0);
1541 Value
*B
= MatMul
->getArgOperand(1);
1543 // We can fold the transpose into the operand that is used to fetch scalars.
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

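      // Hoist the address computation in dominance order so that operands are
      // already materialized before their users once moved above MatMul.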
      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }
  /// Lowers llvm.matrix.multiply.
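  /// The operands of llvm.matrix.multiply(A, B, R, M, C) are interpreted as an
  /// R x M matrix A and an M x C matrix B; the flattened R x C product is
  /// returned.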
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }
  /// Lowers llvm.matrix.transpose.
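  /// llvm.matrix.transpose(In, Rows, Columns) treats the flattened operand as
  /// a Rows x Columns matrix and returns its Columns x Rows transpose.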
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;

    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = UndefValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }
  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }
  /// Lower store instructions, if shape information is available.
  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }
  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }
  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
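  ///
  /// For example, a store of the product of two loaded matrices is rendered
  /// roughly as:
  ///   store(
  ///    multiply.2x6.6x2.double(
  ///     load(addr %A),
  ///     load(addr %B)),
  ///    addr %C)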
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrices. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }
    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }
    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrices, followed by the scalar type name.
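    /// For example, llvm.matrix.multiply on a 2x6 and a 6x2 double matrix is
    /// written as "multiply.2x6.6x2.double".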
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");

        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        write(Tmp);
      }
    }
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }
    /// Special printing for values: for pointers, we print whether they refer
    /// to an (function) external address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
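    /// For example, a pointer into an alloca is printed as "stack addr %p",
    /// any other pointer as "addr %p".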
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }
    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);
        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrices from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }
    const std::string &getResult() {
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
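  ///
  /// The emitted remark has the form
  ///   Lowered with <N> stores, <M> loads, <K> compute ops, <T> exposed
  ///   transposes
  /// followed by the linearized expression; counts shared with other
  /// expressions are reported separately when present.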
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}
    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }
    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }
    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p V. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
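    /// Sub-expressions reachable from a single leaf are counted as exclusive;
    /// those recorded in \p Shared for more than one leaf are accumulated into
    /// the shared counts instead.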
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }
    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }

      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {

          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }
    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  OptimizationRemarkEmitter *ORE = nullptr;
  AAResults *AA = nullptr;
  DominatorTree *DT = nullptr;
  LoopInfo *LI = nullptr;

  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}
namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
  }
};
} // namespace
static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
namespace {

/// A lightweight version of the matrix lowering pass that only requires TTI.
/// Advanced features that require DT, AA or ORE like tiling are disabled. This
/// is used to lower matrix intrinsics if the main lowering pass is not run,
/// for example with -O0.
class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name_minimal[] =
    "Lower the matrix intrinsics (minimal)";
char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
                      "lower-matrix-intrinsics-minimal", pass_name_minimal,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
                    "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
                    false)

Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
  return new LowerMatrixIntrinsicsMinimalLegacyPass();
}