//===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace mlir;
14 // For >1-D vector types, extracts the necessary information to iterate over all
15 // 1-D subvectors in the underlying llrepresentation of the n-D vector
16 // Iterates on the llvm array type until we hit a non-array type (which is
17 // asserted to be an llvm vector type).
18 LLVM::detail::NDVectorTypeInfo
19 LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType
,
20 LLVMTypeConverter
&converter
) {
21 assert(vectorType
.getRank() > 1 && "expected >1D vector type");
22 NDVectorTypeInfo info
;
23 info
.llvmNDVectorTy
= converter
.convertType(vectorType
);
24 if (!info
.llvmNDVectorTy
|| !LLVM::isCompatibleType(info
.llvmNDVectorTy
)) {
25 info
.llvmNDVectorTy
= nullptr;
28 info
.arraySizes
.reserve(vectorType
.getRank() - 1);
29 auto llvmTy
= info
.llvmNDVectorTy
;
30 while (llvmTy
.isa
<LLVM::LLVMArrayType
>()) {
31 info
.arraySizes
.push_back(
32 llvmTy
.cast
<LLVM::LLVMArrayType
>().getNumElements());
33 llvmTy
= llvmTy
.cast
<LLVM::LLVMArrayType
>().getElementType();
35 if (!LLVM::isCompatibleVectorType(llvmTy
))
37 info
.llvm1DVectorTy
= llvmTy
;
41 // Express `linearIndex` in terms of coordinates of `basis`.
42 // Returns the empty vector when linearIndex is out of the range [0, P] where
43 // P is the product of all the basis coordinates.
46 // Basis is an array of nonnegative integers (signed type inherited from
47 // vector shape type).
48 SmallVector
<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef
<int64_t> basis
,
49 unsigned linearIndex
) {
50 SmallVector
<int64_t, 4> res
;
51 res
.reserve(basis
.size());
52 for (unsigned basisElement
: llvm::reverse(basis
)) {
53 res
.push_back(linearIndex
% basisElement
);
54 linearIndex
= linearIndex
/ basisElement
;
58 std::reverse(res
.begin(), res
.end());
62 // Iterate of linear index, convert to coords space and insert splatted 1-D
63 // vector in each position.
64 void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo
&info
,
66 function_ref
<void(ArrayRef
<int64_t>)> fun
) {
68 for (auto s
: info
.arraySizes
)
70 for (unsigned linearIndex
= 0; linearIndex
< ub
; ++linearIndex
) {
71 auto coords
= getCoordinates(info
.arraySizes
, linearIndex
);
72 // Linear index is out of bounds, we are done.
75 assert(coords
.size() == info
.arraySizes
.size());
80 LogicalResult
LLVM::detail::handleMultidimensionalVectors(
81 Operation
*op
, ValueRange operands
, LLVMTypeConverter
&typeConverter
,
82 std::function
<Value(Type
, ValueRange
)> createOperand
,
83 ConversionPatternRewriter
&rewriter
) {
84 auto resultNDVectorType
= op
->getResult(0).getType().cast
<VectorType
>();
86 extractNDVectorTypeInfo(resultNDVectorType
, typeConverter
);
87 auto result1DVectorTy
= resultTypeInfo
.llvm1DVectorTy
;
88 auto resultNDVectoryTy
= resultTypeInfo
.llvmNDVectorTy
;
89 auto loc
= op
->getLoc();
90 Value desc
= rewriter
.create
<LLVM::UndefOp
>(loc
, resultNDVectoryTy
);
91 nDVectorIterate(resultTypeInfo
, rewriter
, [&](ArrayRef
<int64_t> position
) {
92 // For this unrolled `position` corresponding to the `linearIndex`^th
93 // element, extract operand vectors
94 SmallVector
<Value
, 4> extractedOperands
;
95 for (const auto &operand
: llvm::enumerate(operands
)) {
96 extractedOperands
.push_back(rewriter
.create
<LLVM::ExtractValueOp
>(
97 loc
, operand
.value(), position
));
99 Value newVal
= createOperand(result1DVectorTy
, extractedOperands
);
100 desc
= rewriter
.create
<LLVM::InsertValueOp
>(loc
, desc
, newVal
, position
);
102 rewriter
.replaceOp(op
, desc
);
106 LogicalResult
LLVM::detail::vectorOneToOneRewrite(
107 Operation
*op
, StringRef targetOp
, ValueRange operands
,
108 ArrayRef
<NamedAttribute
> targetAttrs
, LLVMTypeConverter
&typeConverter
,
109 ConversionPatternRewriter
&rewriter
) {
110 assert(!operands
.empty());
112 // Cannot convert ops if their operands are not of LLVM type.
113 if (!llvm::all_of(operands
.getTypes(), isCompatibleType
))
116 auto llvmNDVectorTy
= operands
[0].getType();
117 if (!llvmNDVectorTy
.isa
<LLVM::LLVMArrayType
>())
118 return oneToOneRewrite(op
, targetOp
, operands
, targetAttrs
, typeConverter
,
121 auto callback
= [op
, targetOp
, targetAttrs
, &rewriter
](Type llvm1DVectorTy
,
122 ValueRange operands
) {
124 .create(op
->getLoc(), rewriter
.getStringAttr(targetOp
), operands
,
125 llvm1DVectorTy
, targetAttrs
)
129 return handleMultidimensionalVectors(op
, operands
, typeConverter
, callback
,