[Hexagon] Better detection of impossible completions to perfect shuffles
[llvm-project.git] / mlir / lib / Conversion / LLVMCommon / VectorPattern.cpp
blobe95c702d79f387a528d5134538e1d24707a58c5e
1 //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 using namespace mlir;
14 // For >1-D vector types, extracts the necessary information to iterate over all
15 // 1-D subvectors in the underlying llrepresentation of the n-D vector
16 // Iterates on the llvm array type until we hit a non-array type (which is
17 // asserted to be an llvm vector type).
18 LLVM::detail::NDVectorTypeInfo
19 LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
20 LLVMTypeConverter &converter) {
21 assert(vectorType.getRank() > 1 && "expected >1D vector type");
22 NDVectorTypeInfo info;
23 info.llvmNDVectorTy = converter.convertType(vectorType);
24 if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
25 info.llvmNDVectorTy = nullptr;
26 return info;
28 info.arraySizes.reserve(vectorType.getRank() - 1);
29 auto llvmTy = info.llvmNDVectorTy;
30 while (llvmTy.isa<LLVM::LLVMArrayType>()) {
31 info.arraySizes.push_back(
32 llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
33 llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
35 if (!LLVM::isCompatibleVectorType(llvmTy))
36 return info;
37 info.llvm1DVectorTy = llvmTy;
38 return info;
41 // Express `linearIndex` in terms of coordinates of `basis`.
42 // Returns the empty vector when linearIndex is out of the range [0, P] where
43 // P is the product of all the basis coordinates.
45 // Prerequisites:
46 // Basis is an array of nonnegative integers (signed type inherited from
47 // vector shape type).
48 SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis,
49 unsigned linearIndex) {
50 SmallVector<int64_t, 4> res;
51 res.reserve(basis.size());
52 for (unsigned basisElement : llvm::reverse(basis)) {
53 res.push_back(linearIndex % basisElement);
54 linearIndex = linearIndex / basisElement;
56 if (linearIndex > 0)
57 return {};
58 std::reverse(res.begin(), res.end());
59 return res;
62 // Iterate of linear index, convert to coords space and insert splatted 1-D
63 // vector in each position.
64 void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
65 OpBuilder &builder,
66 function_ref<void(ArrayRef<int64_t>)> fun) {
67 unsigned ub = 1;
68 for (auto s : info.arraySizes)
69 ub *= s;
70 for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71 auto coords = getCoordinates(info.arraySizes, linearIndex);
72 // Linear index is out of bounds, we are done.
73 if (coords.empty())
74 break;
75 assert(coords.size() == info.arraySizes.size());
76 fun(coords);
80 LogicalResult LLVM::detail::handleMultidimensionalVectors(
81 Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
82 std::function<Value(Type, ValueRange)> createOperand,
83 ConversionPatternRewriter &rewriter) {
84 auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
85 auto resultTypeInfo =
86 extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
87 auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
88 auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
89 auto loc = op->getLoc();
90 Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
91 nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
92 // For this unrolled `position` corresponding to the `linearIndex`^th
93 // element, extract operand vectors
94 SmallVector<Value, 4> extractedOperands;
95 for (const auto &operand : llvm::enumerate(operands)) {
96 extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
97 loc, operand.value(), position));
99 Value newVal = createOperand(result1DVectorTy, extractedOperands);
100 desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position);
102 rewriter.replaceOp(op, desc);
103 return success();
106 LogicalResult LLVM::detail::vectorOneToOneRewrite(
107 Operation *op, StringRef targetOp, ValueRange operands,
108 ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
109 ConversionPatternRewriter &rewriter) {
110 assert(!operands.empty());
112 // Cannot convert ops if their operands are not of LLVM type.
113 if (!llvm::all_of(operands.getTypes(), isCompatibleType))
114 return failure();
116 auto llvmNDVectorTy = operands[0].getType();
117 if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
118 return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
119 rewriter);
121 auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
122 ValueRange operands) {
123 return rewriter
124 .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
125 llvm1DVectorTy, targetAttrs)
126 ->getResult(0);
129 return handleMultidimensionalVectors(op, operands, typeConverter, callback,
130 rewriter);