//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//
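
// Overview: the patterns below map vector.transfer_read,
// vector.transfer_write, vector.load, and vector.store on memref sources
// onto XeGPU block descriptors (xegpu.create_nd_tdesc) paired with
// xegpu.load_nd and xegpu.store_nd block accesses.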
13 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "mlir/Transforms/Passes.h"
22 #include "llvm/ADT/TypeSwitch.h"
24 #include <algorithm>
25 #include <optional>
27 namespace mlir {
28 #define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
32 using namespace mlir;
34 namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}
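
// Common rank precondition shared by the plain and transfer-based lowerings:
// only 1D and 2D vectors are handled, matching what the nd-descriptor based
// XeGPU block accesses created below can express. For example,
// vector<16xf32> and vector<8x16xf32> pass, while vector<2x8x16xf32> is
// rejected.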
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}
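
// Checks whether a vector transfer operation can be lowered to XeGPU.
// Beyond the shared rank check, the source must be a memref contiguous in
// the innermost dimension, the transfer must be unmasked, and the
// permutation map may only access the innermost source dimensions. For
// example, for a 4D source and a 2D vector, the projected permutation
// (d0, d1, d2, d3) -> (d2, d3) is accepted (as is its transposed variant),
// while (d0, d1, d2, d3) -> (d0, d1) is rejected.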
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
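
// Builds an xegpu.create_nd_tdesc op describing the accessed source region.
// For statically shaped memrefs the descriptor is derived from the type
// alone; otherwise the shape and strides are materialized explicitly via
// memref.dim and index arithmetic, as required by the dynamic form of the op.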
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = getStridesAndOffset(srcTy);

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to
    // be provided explicitly.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());
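
    // dynStrides now holds a Value for each dynamic stride, outermost first.
    // E.g., for a contiguous memref<?x?x?xf32>, dim 1's stride is
    // dim(src, 2) and dim 0's stride is dim(src, 1) * dim(src, 2); the
    // products are accumulated innermost-first and reversed above.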
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
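
// Lowers vector.transfer_read into a descriptor creation plus block load.
// Schematically (the exact XeGPU assembly syntax may differ between dialect
// versions):
//
//   %v = vector.transfer_read %src[%i, %j], %cst
//       : memref<8x16xf32>, vector<8x16xf32>
//
// becomes roughly:
//
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc
//       : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
//
// Transposed reads are mapped to the load's transpose attribute and require
// element types of at least 32 bits.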
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape{vecTy.getShape()};
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
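
// Lowers vector.transfer_write to xegpu.create_nd_tdesc + xegpu.store_nd.
// Only minor identity (non-transposing) maps are supported, so the stored
// vector's shape is used for the descriptor directly.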
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
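
// Lowers vector.load to xegpu.create_nd_tdesc + xegpu.load_nd. Plain loads
// carry no mask, padding, or permutation map, so only the rank precondition
// applies; boundary checking is conservatively kept enabled.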
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        /*boundary_check=*/true, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};
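
// Lowers vector.store analogously to LoadLowering, emitting xegpu.store_nd
// for the write-out.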
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    auto descType =
        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
                                   /*array_length=*/1, /*boundary_check=*/true,
                                   xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};
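
// Pass wrapper that greedily applies all vector-to-XeGPU rewrite patterns
// to the root operation.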
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering>(patterns.getContext());
}

std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}