//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}
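
// Common precondition shared by all lowerings: only 1D and 2D vector payloads
// are lowered to XeGPU nd-descriptor based accesses.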
50 static LogicalResult
storeLoadPreconditions(PatternRewriter
&rewriter
,
51 Operation
*op
, VectorType vecTy
) {
52 // Validate only vector as the basic vector store and load ops guarantee
53 // XeGPU-compatible memref source.
54 unsigned vecRank
= vecTy
.getRank();
55 if (!(vecRank
== 1 || vecRank
== 2))
56 return rewriter
.notifyMatchFailure(op
, "Expects 1D or 2D vector");
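
// Additional preconditions for vector transfer ops: the transfer must be
// unmasked, operate on a memref that is contiguous in the innermost
// dimension, and use a projected permutation map that addresses only the
// innermost source dimensions.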
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
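
// Creates an xegpu.create_nd_tdesc op describing the accessed tile of `src`.
// For statically shaped sources only the offsets are needed; for dynamic
// shapes the source dims are recovered with memref.dim and dynamic strides
// are accumulated with arith.muli, since the descriptor then requires
// explicit shape and stride operands.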
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = getStridesAndOffset(srcTy);

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
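
// Lowers vector.transfer_read into xegpu.create_nd_tdesc + xegpu.load_nd.
// A non-minor-identity permutation map is handled by describing the
// transposed base shape and attaching a [1, 0] transpose to the load.
//
// Rough shape of the rewrite (illustrative sketch only, IR syntax
// approximated):
//
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<128x128xf32>, vector<8x16xf32>
// becomes
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<128x128xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc
//       : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>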
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape{vecTy.getShape()};
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
        readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
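
// Lowers vector.transfer_write into xegpu.create_nd_tdesc + xegpu.store_nd.
// Only minor identity permutation maps are supported, i.e. transposed stores
// are rejected.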
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
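
// Lowers vector.load into xegpu.create_nd_tdesc + xegpu.load_nd. The plain
// vector op already guarantees an XeGPU-compatible memref source, so only
// the vector rank needs to be validated.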
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        /*boundary_check=*/true, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};
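
// Lowers vector.store into xegpu.create_nd_tdesc + xegpu.store_nd, mirroring
// LoadLowering.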
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    auto descType =
        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
                                   /*array_length=*/1, /*boundary_check=*/true,
                                   xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};
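
// Pass driver that greedily applies the vector-to-XeGPU patterns to the
// operation it runs on.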
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace
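
// Populates the given set with all vector-to-XeGPU lowering patterns.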
void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering>(patterns.getContext());
}
std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}