From c3c326213e80abd6db9da83dbf0ab8452780705c Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Fri, 31 Jan 2025 15:47:47 -0800 Subject: [PATCH] [mlir][Vector] Fix `vector.shuffle` folder for poison indices (#124863) This PR fixes the folder of a `vector.shuffle` with constant input vectors in the presence of a poison index. Partially poison vectors are currently not supported in UB so the folder select v1[0] for elements indexed by poison. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 48 +++++++++++++++++------------- mlir/test/Dialect/Vector/canonicalize.mlir | 27 +++++++++++++++++ 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 6a329499c711..93f89eda2da5 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef idxArr, uint64_t begin, size_t width) { } OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) { - VectorType v1Type = getV1VectorType(); + auto v1Type = getV1VectorType(); + auto v2Type = getV2VectorType(); + + assert(!v1Type.isScalable() && !v2Type.isScalable() && + "Vector shuffle does not support scalable vectors"); + // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding // but must be a canonicalization into a vector.broadcast. if (v1Type.getRank() == 0) return {}; - // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1 - if (!v1Type.isScalable() && - isStepIndexArray(getMask(), 0, v1Type.getDimSize(0))) + // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1. + auto mask = getMask(); + if (isStepIndexArray(mask, 0, v1Type.getDimSize(0))) return getV1(); - // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2 - if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() && - isStepIndexArray(getMask(), getV1VectorType().getDimSize(0), - getV2VectorType().getDimSize(0))) + // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2. + if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0))) return getV2(); - Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2(); - if (!lhs || !rhs) + Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2(); + if (!v1Attr || !v2Attr) return {}; - auto lhsType = - llvm::cast(llvm::cast(lhs).getType()); // Only support 1-D for now to avoid complicated n-D DenseElementsAttr // manipulation. - if (lhsType.getRank() != 1) + if (v1Type.getRank() != 1) return {}; - int64_t lhsSize = lhsType.getDimSize(0); + + int64_t v1Size = v1Type.getDimSize(0); SmallVector results; - auto lhsElements = llvm::cast(lhs).getValues(); - auto rhsElements = llvm::cast(rhs).getValues(); - for (int64_t i : this->getMask()) { - if (i >= lhsSize) { - results.push_back(rhsElements[i - lhsSize]); + auto v1Elements = cast(v1Attr).getValues(); + auto v2Elements = cast(v2Attr).getValues(); + for (int64_t maskIdx : mask) { + Attribute indexedElm; + // Select v1[0] for poison indices. + // TODO: Return a partial poison vector when supported by the UB dialect. + if (maskIdx == ShuffleOp::kPoisonIndex) { + indexedElm = v1Elements[0]; } else { - results.push_back(lhsElements[i]); + indexedElm = + maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size]; } + + results.push_back(indexedElm); } return DenseElementsAttr::get(getResultVectorType(), results); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index f9e3b772f9f0..6858f0d56e64 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2006,6 +2006,23 @@ func.func @shuffle_1d() -> vector<4xi32> { return %shuffle : vector<4xi32> } +// ----- + +// Check that poison indices pick the first element of the first non-poison +// input vector. That is, %v[0] (i.e., 5) in this test. + +// CHECK-LABEL: func @shuffle_1d_poison_idx +// CHECK: %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32> +// CHECK: return %[[V]] +func.func @shuffle_1d_poison_idx() -> vector<4xi32> { + %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32> + %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32> + %shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32> + return %shuffle : vector<4xi32> +} + +// ----- + // CHECK-LABEL: func @shuffle_canonicalize_0d func.func @shuffle_canonicalize_0d(%v0 : vector, %v1 : vector) -> vector<1xi32> { // CHECK: vector.broadcast %{{.*}} : vector to vector<1xi32> @@ -2013,6 +2030,8 @@ func.func @shuffle_canonicalize_0d(%v0 : vector, %v1 : vector) -> vect return %shuffle : vector<1xi32> } +// ----- + // CHECK-LABEL: func @shuffle_fold1 // CHECK: %arg0 : vector<4xi32> func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> { @@ -2020,6 +2039,8 @@ func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi return %shuffle : vector<4xi32> } +// ----- + // CHECK-LABEL: func @shuffle_fold2 // CHECK: %arg1 : vector<2xi32> func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi32> { @@ -2027,6 +2048,8 @@ func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi return %shuffle : vector<2xi32> } +// ----- + // CHECK-LABEL: func @shuffle_fold3 // CHECK: return %arg0 : vector<4x5x6xi32> func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<4x5x6xi32> { @@ -2034,6 +2057,8 @@ func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve return %shuffle : vector<4x5x6xi32> } +// ----- + // CHECK-LABEL: func @shuffle_fold4 // CHECK: return %arg1 : vector<2x5x6xi32> func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<2x5x6xi32> { @@ -2041,6 +2066,8 @@ func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve return %shuffle : vector<2x5x6xi32> } +// ----- + // CHECK-LABEL: func @shuffle_nofold1 // CHECK: %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32> // CHECK: return %[[V]] -- 2.11.4.GIT