//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//    transposed.
//  * Improve cost-modeling, e.g. choose different number of rows/columns
//    for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));
enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));
/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}
/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}
/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}
/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the start of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//   0   v_0_0  v_0_1  v_0_2  v_0_3
//   1   v_1_0  v_1_1  v_1_2  v_1_3
//   2   v_2_0  v_2_1  v_2_2  v_2_3
//   3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major layout).
///    The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note that
///       we lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
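///
/// Illustrative sketch (not part of the source pass documentation; names and
/// shapes below are made up): for a column-major 2x2 single-precision multiply
/// such as
///
///   %c = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(
///            <4 x float> %a, <4 x float> %b, i32 2, i32 2, i32 2)
///
/// step 2.1 splits %a and %b into two <2 x float> column vectors via
/// shufflevector, step 2.2 combines them with vector fmul/fadd (or
/// llvm.fmuladd), and step 2.3 re-embeds the two result columns into a flat
/// <4 x float> only for users that carry no shape information.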
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications. This is the number of
    /// transposes that we failed to make "free" via such optimizations.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };
  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }
    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }
    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    bool IsColumnMajor;
    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }

    unsigned getStride() const {
      if (IsColumnMajor)
        return NumRows;
      return NumColumns;
    }

    unsigned getNumVectors() const {
      if (IsColumnMajor)
        return NumColumns;
      return NumRows;
    }

    /// Returns the transposed shape.
    ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
  };
  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_column_major_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well. A ValueMap is used so that when
  /// sub-passes like optimizeTransposes perform RAUW the map stays
  /// up-to-date.
  ValueMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we are not replacing all
  /// users of a lowered instruction, if shape information is available and
  /// those need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());

    return FMF;
  }
public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}
  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }
  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const {
    return !DT;
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }
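  // Example (illustrative only; the exact numbers depend on the target's TTI):
  // for ST = double and N = 4 the operation covers 64 * 4 = 256 bits. With
  // 128-bit fixed-width vector registers this yields ceil(256 / 128) = 2
  // vector ops; with 256-bit registers it yields 1.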
  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
  /// into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;
      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }
  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }
  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                 m_Value(), m_Value(), m_Value(), m_Value(M),
                                 m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
  /// Propagate the shape to operands of instructions with shape information.
  /// \p Worklist contains the instruction for which we already know the shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };

    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to the
    // worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// them on both sides of \p Operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }
  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo so we insert
    // it conditionally instead.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, S->second});
    }
    Old.replaceAllUsesWith(New);
  }
  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidation. If
  /// this returns nullptr, no new instruction was created.
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      // We know that the transposed operand is of shape RxC.
      // And when multiplied with a scalar, the shape is preserved.
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");

            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }
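  // Illustrative sketch of the (A * B)^t case above (value names are made up,
  // types elided):
  //   %ab  = call ... @llvm.matrix.multiply(... %A, ... %B, i32 R, i32 K, i32 C)
  //   %abt = call ... @llvm.matrix.transpose(... %ab, i32 R, i32 C)
  // becomes
  //   %B_t  = call ... @llvm.matrix.transpose(... %B, i32 K, i32 C)
  //   %A_t  = call ... @llvm.matrix.transpose(... %A, i32 R, i32 K)
  //   %mmul = call ... @llvm.matrix.multiply(... %B_t, ... %A_t, i32 C, i32 K, i32 R)
  // and the newly created transposes are expected to sink further or fold into
  // their operands on later iterations.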
  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        T.eraseFromParent();
      if (A->use_empty())
        cast<Instruction>(A)->eraseFromParent();
      if (A != B && B->use_empty())
        cast<Instruction>(B)->eraseFromParent();
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
                                                           R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
      IRBuilder<> Builder(&I);
      Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
      setShapeInfo(Add, {C, R});
      MatrixBuilder MBuilder(Builder);
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
  }
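  // Illustrative sketch of the A^t + B^t case above (value names are made up):
  //   %s = fadd (transpose %A), (transpose %B)
  // is rewritten to
  //   %mfadd_t = transpose (fadd %A, %B)   ; the "mfadd" / "mfadd_t" values
  // so a consuming multiply or a transpose-of-transpose fold can eliminate the
  // remaining transpose.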
  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // First sink all transposes inside matmuls and adds, hoping that we end up
    // with NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II. By default continue on the next/prev instruction.
        ++II;
        if (Instruction *NewInst = sinkTranspose(I, II))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul or a TT add, lift the transpose. We may be able
    // to fold into consuming multiply or add.
    for (BasicBlock &BB : Func) {
      for (Instruction &I : llvm::make_early_inc_range(BB)) {
        liftTranspose(I);
      }
    }
  }
  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }
    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates for
    // fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

    // Third, try to fuse candidates.
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);

    Changed = !FusedInsts.empty();

    // Fourth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains.
    //
    // Because we add to ToRemove during fusion we can't guarantee that defs
    // are before uses. Change uses to poison temporarily as these should get
    // removed as well.
    //
    // For verification, we keep track of where we changed uses to poison in
    // PoisonedInsts and then check that we in fact remove them.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }
  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }
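  // Example (illustrative only): for double elements, a constant stride of 6
  // elements and an initial alignment of 16, vector Idx = 1 starts at byte
  // offset 6 * 8 = 48, so it keeps commonAlignment(16, 48) = 16; a stride of 5
  // (byte offset 40) would reduce it to 8.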
  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = Ptr;
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }
  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
  /// starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {

    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);

    return loadMatrix(TileTy, TileStart, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }
  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
                 bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }
  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
  /// MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());

    storeMatrix(TileTy, StoreVal, TileStart, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }
  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = Ptr;
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }
  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }
  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }
  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }
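  // Example (illustrative only): with contraction allowed and a non-null FP
  // accumulator each step becomes a single call
  //   %s = call <4 x double> @llvm.fmuladd.v4f64(<4 x double> %a,
  //                                              <4 x double> %b,
  //                                              <4 x double> %sum)
  // otherwise it is emitted as a separate fmul followed by fadd (or mul/add
  // for integer elements).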
  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
  /// Special case for MatMul lowering. Prevents scalar loads of row-major
  /// vectors. Lowers to vector reduction add instead of sequential add if
  /// reassociation is enabled.
  void lowerDotProduct(CallInst *MatMul,
                       SmallPtrSet<Instruction *, 16> &FusedInsts,
                       FastMathFlags FMF) {
    if (FusedInsts.contains(MatMul) ||
        MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
      return;

    Value *LHS = MatMul->getArgOperand(0);
    Value *RHS = MatMul->getArgOperand(1);

    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
    bool IsIntVec = ElementType->isIntegerTy();

    // Floating point reductions require reassociation.
    if (!IsIntVec && !FMF.allowReassoc())
      return;

    auto CanBeFlattened = [this](Value *Op) {
      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
        return true;
      return match(
          Op, m_OneUse(m_CombineOr(
                  m_Load(m_Value()),
                  m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
                              m_Intrinsic<Intrinsic::matrix_column_major_load>(
                                  m_Value(), m_SpecificInt(1))))));
    };
    // Returns the cost benefit of using \p Op with the dot product lowering. If
    // the returned cost is < 0, the argument is cheaper to use in the
    // dot-product lowering.
    auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
      if (!isa<Instruction>(Op))
        return InstructionCost(0);

      FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
      Type *EltTy = VecTy->getElementType();

      if (!CanBeFlattened(Op)) {
        InstructionCost EmbedCost(0);
        // Roughly estimate the cost for embedding the columns into a vector.
        for (unsigned I = 1; I < N; ++I)
          EmbedCost +=
              TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                                 std::nullopt, TTI::TCK_RecipThroughput);
        return EmbedCost;
      }

      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
        InstructionCost OriginalCost =
            TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(),
                                       EltTy) *
            N;
        InstructionCost NewCost = TTI.getArithmeticInstrCost(
            cast<Instruction>(Op)->getOpcode(), VecTy);
        return NewCost - OriginalCost;
      }

      if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
        // The transpose can be skipped for the dot product lowering, roughly
        // estimate the savings as the cost of embedding the columns in a
        // vector.
        InstructionCost EmbedCost(0);
        for (unsigned I = 1; I < N; ++I)
          EmbedCost -=
              TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                                 std::nullopt, TTI::TCK_RecipThroughput);
        return EmbedCost;
      }

      // Costs for loads.
      if (N == 1)
        return InstructionCost(0);

      return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
             N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
    };
    auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);

    // We compare the costs of a vector.reduce.add to sequential add.
    int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
    int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
    InstructionCost ReductionCost =
        TTI.getArithmeticReductionCost(
            AddOpCode, cast<VectorType>(LHS->getType()),
            IsIntVec ? std::nullopt : std::optional(FMF)) +
        TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
    InstructionCost SequentialAddCost =
        TTI.getArithmeticInstrCost(AddOpCode, ElementType) *
            (LShape.NumColumns - 1) +
        TTI.getArithmeticInstrCost(MulOpCode, ElementType) *
            (LShape.NumColumns);
    if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
      return;

    FusedInsts.insert(MatMul);
    IRBuilder<> Builder(MatMul);
    auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
                       this](Value *Op) -> Value * {
      // Matmul must be the only user of loads because we don't use LowerLoad
      // for row vectors (LowerLoad results in scalar loads and shufflevectors
      // instead of single vector load).
      if (!CanBeFlattened(Op))
        return Op;

      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
        ShapeMap[Op] = ShapeMap[Op].t();
        return Op;
      }

      FusedInsts.insert(cast<Instruction>(Op));
      // If vector uses the builtin load, lower to a LoadInst
      Value *Arg;
      if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                        m_Value(Arg)))) {
        auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
        Op->replaceAllUsesWith(NewLoad);
        cast<Instruction>(Op)->eraseFromParent();
        return NewLoad;
      } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(Arg)))) {
        ToRemove.push_back(cast<Instruction>(Op));
        return Arg;
      }

      return Op;
    };
    LHS = FlattenArg(LHS);

    // Insert mul/fmul and llvm.vector.reduce.fadd
    Value *Mul =
        IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS);

    Value *Result;
    if (IsIntVec)
      Result = Builder.CreateAddReduce(Mul);
    else {
      Result = Builder.CreateFAddReduce(
          ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
                          0.0),
          Mul);
      cast<Instruction>(Result)->setFastMathFlags(FMF);
    }

    // pack scalar back into a matrix and then replace matmul inst
    Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
                                         Result, uint64_t(0));
    MatMul->replaceAllUsesWith(Result);
    FusedInsts.insert(MatMul);
    ToRemove.push_back(MatMul);
  }
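  // Illustrative sketch (names are made up): a 1x4 * 4x1 float product with
  // reassociation allowed is roughly lowered to
  //   %m = fmul <4 x float> %lhs, %rhs
  //   %r = call float @llvm.vector.reduce.fadd.v4f32(float 0.0, <4 x float> %m)
  //   %d = insertelement <1 x float> poison, float %r, i64 0
  // instead of four scalar multiplies and three sequential adds.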
  /// Compute \p Result += \p A * \p B for input matrices with left-associating
  /// addition.
  ///
  /// We can fold a transpose into the operand that is used to extract scalars.
  /// This is the first operand with row-major and the second with
  /// column-major layout. If \p IsScalarMatrixTransposed we assume the
  /// appropriate operand is transposed.
  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
        1U);
    unsigned R = Result.getNumRows();
    unsigned C = Result.getNumColumns();
    unsigned M = A.getNumColumns();

    bool IsFP = Result.getElementType()->isFloatingPointTy();
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");
    unsigned NumComputeOps = 0;

    Builder.setFastMathFlags(FMF);

    if (A.isColumnMajor()) {
      // Multiply columns from the first operand with scalars from the second
      // operand. Then move along the K axes and accumulate the columns. With
      // this the adds can be vectorized without reassociation.
      for (unsigned J = 0; J < C; ++J) {
        unsigned BlockSize = VF;
        // If Result is zero, we don't need to accumulate in the K==0 iteration.
        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

        for (unsigned I = 0; I < R; I += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (I + BlockSize > R)
            BlockSize /= 2;

          Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                               : nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *L = A.extractVector(I, K, BlockSize, Builder);
            Value *RH = Builder.CreateExtractElement(
                B.getColumn(IsScalarMatrixTransposed ? K : J),
                IsScalarMatrixTransposed ? J : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(J,
                           insertVector(Result.getVector(J), I, Sum, Builder));
        }
      }
    } else {
      // Multiply rows from the second operand with scalars from the first
      // operand. Then move along the K axes and accumulate the rows. With this
      // the adds can be vectorized without reassociation.
      for (unsigned I = 0; I < R; ++I) {
        unsigned BlockSize = VF;
        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
        for (unsigned J = 0; J < C; J += BlockSize) {
          // Gradually lower the vectorization factor to cover the remainder.
          while (J + BlockSize > C)
            BlockSize /= 2;

          Value *Sum = nullptr;
          for (unsigned K = 0; K < M; ++K) {
            Value *R = B.extractVector(K, J, BlockSize, Builder);
            Value *LH = Builder.CreateExtractElement(
                A.getVector(IsScalarMatrixTransposed ? K : I),
                IsScalarMatrixTransposed ? I : K);
            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
            Sum =
                createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
                             IsFP, Builder, FMF.allowContract(), NumComputeOps);
          }
          Result.setVector(I,
                           insertVector(Result.getVector(I), J, Sum, Builder));
        }
      }
    }
    Result.addNumComputeOps(NumComputeOps);
  }
  /// Ensure that the memory in \p Load does not alias \p Store by potentially
  /// copying it to a new location. This new or otherwise the original location
  /// is returned.
  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
                               CallInst *MatMul) {
    MemoryLocation StoreLoc = MemoryLocation::get(Store);
    MemoryLocation LoadLoc = MemoryLocation::get(Load);

    // If we can statically determine noalias we're good.
    if (AA->isNoAlias(LoadLoc, StoreLoc))
      return Load->getPointerOperand();

    // Create code to check if the memory locations of the Load and Store
    // overlap and if they do, copy Load's operand to a new buffer.

    // First, create new blocks for the 2nd part of the check and the copy.
    BasicBlock *Check0 = MatMul->getParent();
    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
    // DT. Manually collect dominator tree updates, to avoid unnecessary work,
    // as we adjust Check0 and Check1's branches.
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    for (BasicBlock *Succ : successors(Check0))
      DTUpdates.push_back({DT->Delete, Check0, Succ});

    BasicBlock *Check1 =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "alias_cont");
    BasicBlock *Copy =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "copy");
    BasicBlock *Fusion =
        SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
                   nullptr, "no_alias");

    // Check if the loaded memory location begins before the end of the store
    // location. If the condition holds, they might overlap, otherwise they are
    // guaranteed to not overlap.
    IRBuilder<> Builder(MatMul);
    Check0->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check0);
    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
    Value *StoreBegin = Builder.CreatePtrToInt(
        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
    Value *StoreEnd = Builder.CreateAdd(
        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
        "store.end", true, true);
    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
                                              IntPtrTy, "load.begin");
    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
                         Fusion);

    // Check if the store begins before the end of the load location. If the
    // condition holds, they alias, otherwise they are guaranteed to not
    // overlap.
    Check1->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(Check1, Check1->begin());
    Value *LoadEnd = Builder.CreateAdd(
        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
        "load.end", true, true);
    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
                         Fusion);

    // Copy load operand to new alloca.
    Builder.SetInsertPoint(Copy, Copy->begin());
    auto *VT = cast<FixedVectorType>(Load->getType());
    // Use an array type for the alloca, to avoid potentially huge alignment
    // requirements for large vector types.
    auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements());
    AllocaInst *Alloca =
        Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace());

    Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(),
                         Load->getAlign(), LoadLoc.Size.getValue());
    Builder.SetInsertPoint(Fusion, Fusion->begin());
    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
    PHI->addIncoming(Load->getPointerOperand(), Check0);
    PHI->addIncoming(Load->getPointerOperand(), Check1);
    PHI->addIncoming(Alloca, Copy);

    // Adjust DT.
    DTUpdates.push_back({DT->Insert, Check0, Check1});
    DTUpdates.push_back({DT->Insert, Check0, Fusion});
    DTUpdates.push_back({DT->Insert, Check1, Copy});
    DTUpdates.push_back({DT->Insert, Check1, Fusion});
    DT->applyUpdates(DTUpdates);
    return PHI;
  }

  bool isFusionProfitable(CallInst *MatMul) {
    if (ForceFusion)
      return true;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    const unsigned VF = std::max<unsigned>(
        TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue() /
            EltType->getPrimitiveSizeInBits().getFixedValue(),
        1U);

    // Cost model for tiling
    //
    // For tiling to be beneficial, we need reuse either along the R or
    // the C axis. We vectorize along the R axis so that means at least
    // 3 elements.
    // TODO: Also consider cost of copying if operands alias.
    if (R <= VF && C == 1)
      return false;
    // Then we need enough elements to exceed the number of vector
    // registers we have. Note that this is an oversimplification since
    // fusing also takes some extra loads which may exceed the number of
    // reloads necessary.
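    // For example, with VF = 4 a 16x16 * 16x16 multiply needs
    // Op0Regs = Op1Regs = (16 / 4) * 16 = 64 vector registers' worth of
    // operand data, which exceeds the register file on typical targets, so
    // fusion is considered profitable.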
    unsigned Op0Regs = (R + VF - 1) / VF * M;
    unsigned Op1Regs = (M + VF - 1) / VF * C;
    return Op0Regs + Op1Regs >
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true));
  }

  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
    MatrixTy Res;
    auto *ColumType = FixedVectorType::get(EltType, R);
    for (unsigned I = 0; I < C; ++I)
      Res.addVector(ConstantAggregateZero::get(ColumType));
    return Res;
  }

  void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    // Create the main tiling loop nest.
    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
    Instruction *InsertI = cast<Instruction>(MatMul);
    BasicBlock *Start = InsertI->getParent();
    BasicBlock *End =
        SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
    IRBuilder<> Builder(MatMul);
    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU, *LI);

    Type *TileVecTy =
        FixedVectorType::get(MatMul->getType()->getScalarType(), TileSize);
    MatrixTy TileResult;
    // Insert in the inner loop header.
    Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
    // Create PHI nodes for the result columns to accumulate across iterations.
    SmallVector<PHINode *, 4> ColumnPhis;
    for (unsigned I = 0; I < TileSize; I++) {
      auto *Phi = Builder.CreatePHI(TileVecTy, 2, "result.vec." + Twine(I));
      Phi->addIncoming(ConstantAggregateZero::get(TileVecTy),
                       TI.RowLoop.Header->getSingleSuccessor());
      TileResult.addVector(Phi);
      ColumnPhis.push_back(Phi);
    }
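
    // Each accumulator PHI starts at zero when entering from the enclosing row
    // loop and is fed back from the K-loop latch (see the addIncoming calls
    // after the store below), so each trip of the inner loop adds one
    // TileSize x TileSize partial product into the running result tile.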

    // Insert in the inner loop body, which computes
    //   Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
    Builder.SetInsertPoint(InnerBody->getTerminator());
    // Load tiles of the operands.
    MatrixTy A =
        loadMatrix(LPtr, {}, false, LShape, TI.RowLoop.Index, TI.KLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    MatrixTy B =
        loadMatrix(RPtr, {}, false, RShape, TI.KLoop.Index, TI.ColumnLoop.Index,
                   {TileSize, TileSize}, EltType, Builder);
    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
                       getFastMathFlags(MatMul));
    // Store result after the inner loop is done.
    Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
    storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
                Store->isVolatile(), {LShape.NumRows, RShape.NumColumns},
                TI.RowLoop.Index, TI.ColumnLoop.Index, EltType, Builder);

    for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
      ColumnPhis[I]->addIncoming(TileResult.getVector(I), TI.KLoop.Latch);

    // Force unrolling of a few iterations of the inner loop, to make sure there
    // is enough work per iteration.
    // FIXME: The unroller should make this decision directly instead, but
    // currently the cost-model is not up to the task.
    unsigned InnerLoopUnrollCount = std::min(10u, LShape.NumColumns / TileSize);
    addStringMetadataToLoop(LI->getLoopFor(TI.KLoop.Header),
                            "llvm.loop.unroll.count", InnerLoopUnrollCount);
  }

  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                      StoreInst *Store,
                      SmallPtrSetImpl<Instruction *> &FusedInsts) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Tiling only supported for column-major matrixes at the moment!");
    if (!isFusionProfitable(MatMul))
      return;

    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    const unsigned M = LShape.NumColumns;
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();

    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
    Value *CPtr = Store->getPointerOperand();
&& (R
% TileSize
== 0 && C
% TileSize
== 0))
1769 createTiledLoops(MatMul
, APtr
, LShape
, BPtr
, RShape
, Store
);
1771 IRBuilder
<> Builder(Store
);
1772 for (unsigned J
= 0; J
< C
; J
+= TileSize
)
1773 for (unsigned I
= 0; I
< R
; I
+= TileSize
) {
1774 const unsigned TileR
= std::min(R
- I
, unsigned(TileSize
));
1775 const unsigned TileC
= std::min(C
- J
, unsigned(TileSize
));
1776 MatrixTy Res
= getZeroMatrix(EltType
, TileR
, TileC
);
1778 for (unsigned K
= 0; K
< M
; K
+= TileSize
) {
1779 const unsigned TileM
= std::min(M
- K
, unsigned(TileSize
));
1781 loadMatrix(APtr
, LoadOp0
->getAlign(), LoadOp0
->isVolatile(),
1782 LShape
, Builder
.getInt64(I
), Builder
.getInt64(K
),
1783 {TileR
, TileM
}, EltType
, Builder
);
1785 loadMatrix(BPtr
, LoadOp1
->getAlign(), LoadOp1
->isVolatile(),
1786 RShape
, Builder
.getInt64(K
), Builder
.getInt64(J
),
1787 {TileM
, TileC
}, EltType
, Builder
);
1788 emitMatrixMultiply(Res
, A
, B
, Builder
, true, false,
1789 getFastMathFlags(MatMul
));
1791 storeMatrix(Res
, CPtr
, Store
->getAlign(), Store
->isVolatile(), {R
, M
},
1792 Builder
.getInt64(I
), Builder
.getInt64(J
), EltType
,
1797 // Mark eliminated instructions as fused and remove them.
1798 FusedInsts
.insert(Store
);
1799 FusedInsts
.insert(MatMul
);
1800 Store
->eraseFromParent();
1801 MatMul
->eraseFromParent();
1802 if (LoadOp0
->hasNUses(0)) {
1803 FusedInsts
.insert(LoadOp0
);
1804 LoadOp0
->eraseFromParent();
1806 if (LoadOp1
!= LoadOp0
&& LoadOp1
->hasNUses(0)) {
1807 FusedInsts
.insert(LoadOp1
);
1808 LoadOp1
->eraseFromParent();

  /// Try to lower matrix multiply chains by fusing operations.
  ///
  /// Call finalizeLowering on lowered instructions. Instructions that are
  /// completely eliminated by fusion are added to \p FusedInsts.
  void LowerMatrixMultiplyFused(CallInst *MatMul,
                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
    if (!FuseMatrix || !DT)
      return;

    assert(AA && LI && "Analyses should be available");

    Value *A = MatMul->getArgOperand(0);
    Value *B = MatMul->getArgOperand(1);

    // We can fold the transpose into the operand that is used to fetch scalars.
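    // For the default column-major layout this means multiply(A, transpose(T))
    // is lowered directly from A and T, without materializing the transposed
    // matrix, since the scalar operand is read element-wise anyway.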
    Value *T;
    if (MatrixLayout == MatrixLayoutTy::ColumnMajor
            ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
            : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
      IRBuilder<> Builder(MatMul);
      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
      const unsigned R = LShape.NumRows;
      const unsigned M = LShape.NumColumns;
      const unsigned C = RShape.NumColumns;

      MatrixTy MA;
      MatrixTy MB;

      Value *Transpose;
      if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
        MA = getMatrix(A, ShapeInfo(R, M), Builder);
        MB = getMatrix(T, ShapeInfo(C, M), Builder);
        Transpose = B;
      } else {
        MA = getMatrix(T, ShapeInfo(R, M), Builder);
        MB = getMatrix(B, ShapeInfo(C, M), Builder);
        Transpose = A;
      }

      // Initialize the output
      MatrixTy Result(R, C, EltType);

      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
                         getFastMathFlags(MatMul));

      FusedInsts.insert(MatMul);
      if (Transpose->hasOneUse()) {
        FusedInsts.insert(cast<Instruction>(Transpose));
        ToRemove.push_back(cast<Instruction>(Transpose));
        // TODO: add a fake entry for the folded instruction so that this is
        // included in the expression in the remark.
        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
      }
      finalizeLowering(MatMul, Result, Builder);
      return;
    }

    if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
      return;

    // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
    // since the single store user will be lowered as part of this.
    auto *LoadOp0 = dyn_cast<LoadInst>(A);
    auto *LoadOp1 = dyn_cast<LoadInst>(B);
    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
    if (LoadOp0 && LoadOp1 && Store) {
      // The store address must dominate the MatMul instruction, otherwise
      // we create invalid IR.
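      // Walk the operands of the store address bottom-up: anything that does
      // not already dominate MatMul gets hoisted (in dominance order) so the
      // address is available where the fused code is emitted. Bail out on PHIs
      // and on instructions that may touch memory, since those cannot be
      // hoisted safely.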
      SetVector<Value *> WorkList;
      WorkList.insert(Store->getOperand(1));
      SmallVector<Instruction *> ToHoist;
      for (unsigned I = 0; I != WorkList.size(); ++I) {
        Value *Current = WorkList[I];
        auto *CurrI = dyn_cast<Instruction>(Current);
        if (!CurrI)
          continue;
        if (isa<PHINode>(CurrI))
          return;
        if (DT->dominates(CurrI, MatMul))
          continue;
        if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
          return;
        ToHoist.push_back(CurrI);
        WorkList.insert(CurrI->op_begin(), CurrI->op_end());
      }

      sort(ToHoist, [this](Instruction *A, Instruction *B) {
        return DT->dominates(A, B);
      });
      for (Instruction *I : ToHoist)
        I->moveBefore(MatMul);

      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
      return;
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
    assert(Lhs.getElementType() == Rhs.getElementType() &&
           "Matrix multiply argument element types do not match.");

    const unsigned R = LShape.NumRows;
    const unsigned C = RShape.NumColumns;
    assert(LShape.NumColumns == RShape.NumRows);

    // Initialize the output
    MatrixTy Result(R, C, EltType);
    assert(Lhs.getElementType() == Result.getElementType() &&
           "Matrix multiply result element type does not match arguments.");

    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                       getFastMathFlags(MatMul));
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    MatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    const unsigned NewNumVecs =
        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
    const unsigned NewNumElts =
        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
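
    // Result vector I is built by extracting element I from every input vector
    // and inserting it at that input vector's index, i.e. element (Row, Col)
    // of the input ends up as element (Col, Row) of the result.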
    for (unsigned I = 0; I < NewNumVecs; ++I) {
      // Build a single result vector. First initialize it.
      Value *ResultVector = PoisonValue::get(
          FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
      // Go through the old elements and insert them into the resulting vector.
      for (auto J : enumerate(InputMatrix.vectors())) {
        Value *Elt = Builder.CreateExtractElement(J.value(), I);
        // Row and column indices are transposed.
        ResultVector =
            Builder.CreateInsertElement(ResultVector, Elt, J.index());
      }
      Result.addVector(ResultVector);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
            .addNumExposedTransposes(1),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Inst->getAlign(),
              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
              I->second);
    return true;
  }

  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy A = getMatrix(Lhs, Shape, Builder);
    MatrixTy B = getMatrix(Rhs, Shape, Builder);
    assert(A.isColumnMajor() == B.isColumnMajor() &&
           Result.isColumnMajor() == A.isColumnMajor() &&
           "operands must agree on matrix layout");

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform binary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Lower unary operators, if shape information is available.
  bool VisitUnaryOperator(UnaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Op = Inst->getOperand(0);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    MatrixTy Result;
    MatrixTy M = getMatrix(Op, Shape, Builder);

    Builder.setFastMathFlags(getFastMathFlags(Inst));

    // Helper to perform unary op on vectors.
    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
      switch (Inst->getOpcode()) {
      case Instruction::FNeg:
        return Builder.CreateFNeg(Op);
      default:
        llvm_unreachable("Unsupported unary operator for matrix");
      }
    };

    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
      Result.addVector(BuildVectorOp(M.getVector(I)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                             Result.getNumVectors()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
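  ///
  /// The linearized form of a lowered multiply of two loaded matrixes looks
  /// roughly like:
  ///   multiply.2x6.6x6.double(
  ///    column.major.load.2x6.double(addr %A, 6),
  ///    column.major.load.6x6.double(addr %B, 6))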
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, MatrixTy> &Inst2Matrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return getUnderlyingObject(V);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2Matrix.find(V);
      if (M == Inst2Matrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.starts_with("llvm.matrix")) {
          write(Name);
          return;
        }
        auto *II = cast<IntrinsicInst>(CI);
        write(Intrinsic::getBaseName(II->getIntrinsicID())
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp;
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_column_major_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        write(SS.str());
      }
    }
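
    /// Return the number of trailing operands of \p CI that only describe the
    /// shape. They are dropped when linearizing a call, so only the data
    /// operands of llvm.matrix.* calls show up in the remark.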
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
          return 2;
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          return 3;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to an (function) external address or a stack address; for other values
    /// we print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else if (isMatrix(V))
        TmpStream << "matrix";
      else
        TmpStream << "scalar";
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.count(Leaf));

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
        NumOpsToBreak = 2;

      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    const std::string &getResult() {
      return Str;
    }
  };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
  ///    mapping. Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, MatrixTy> &Inst2Matrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2Matrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      if (!ExprsInSubprogram.count(V))
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
    }

    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      auto CM = Inst2Matrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (const auto &KV : Inst2Matrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());
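
          // The emitted remark reads e.g. "Lowered with 4 stores, 8 loads,
          // 16 compute ops, 0 exposed transposes", followed by the shared
          // counts (if any) and the linearized expression tree.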
          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops, "
              << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
              << " exposed transposes";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};

} // namespace
LowerMatrixIntrinsicsPass::run(Function
&F
,
2524 FunctionAnalysisManager
&AM
) {
2525 auto &TTI
= AM
.getResult
<TargetIRAnalysis
>(F
);
2526 OptimizationRemarkEmitter
*ORE
= nullptr;
2527 AAResults
*AA
= nullptr;
2528 DominatorTree
*DT
= nullptr;
2529 LoopInfo
*LI
= nullptr;
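
  // The remark emitter and the analyses used for fusion are only queried for
  // the full lowering; otherwise the pointers stay null and the corresponding
  // features are skipped.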
  if (!Minimal) {
    ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    AA = &AM.getResult<AAManager>(F);
    DT = &AM.getResult<DominatorTreeAnalysis>(F);
    LI = &AM.getResult<LoopAnalysis>(F);
  }

  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    if (!Minimal) {
      PA.preserve<LoopAnalysis>();
      PA.preserve<DominatorTreeAnalysis>();
    }
    return PA;
  }
  return PreservedAnalyses::all();
}

void LowerMatrixIntrinsicsPass::printPipeline(
    raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
  static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
      OS, MapClassName2PassName);
  OS << '<';
  if (Minimal)
    OS << "minimal";
  OS << '>';
}