[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Conversion / ArithToAMDGPU / ArithToAMDGPU.cpp
blob6b9cbaf57676c263a5c0e3a0873ee408581b0bc3
1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
11 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Arith/Utils/Utils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 namespace mlir {
25 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
26 #include "mlir/Conversion/Passes.h.inc"
27 } // namespace mlir
29 using namespace mlir;
30 using namespace mlir::amdgpu;
32 namespace {
33 struct ArithToAMDGPUConversionPass final
34 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
35 using impl::ArithToAMDGPUConversionPassBase<
36 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
38 void runOnOperation() override;
41 struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
42 using OpRewritePattern::OpRewritePattern;
44 LogicalResult match(arith::ExtFOp op) const override;
45 void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
48 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
49 bool saturateFP8 = false;
50 TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
51 Chipset chipset)
52 : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
53 chipset(chipset) {}
54 Chipset chipset;
56 LogicalResult match(arith::TruncFOp op) const override;
57 void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
60 struct TruncfToFloat16RewritePattern final
61 : public OpRewritePattern<arith::TruncFOp> {
63 using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
65 LogicalResult match(arith::TruncFOp op) const override;
66 void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
69 } // end namespace
71 static Value castF32To(Type elementType, Value f32, Location loc,
72 PatternRewriter &rewriter) {
73 if (elementType.isF32())
74 return f32;
75 if (elementType.getIntOrFloatBitWidth() < 32)
76 return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
77 if (elementType.getIntOrFloatBitWidth() > 32)
78 return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
79 llvm_unreachable("The only 32-bit float type is f32");
82 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
83 Type inType = op.getIn().getType();
84 if (auto inVecType = dyn_cast<VectorType>(inType)) {
85 if (inVecType.isScalable())
86 return failure();
87 inType = inVecType.getElementType();
89 return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
92 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
93 PatternRewriter &rewriter) const {
94 Location loc = op.getLoc();
95 Value in = op.getIn();
96 Type outElemType = getElementTypeOrSelf(op.getOut().getType());
97 auto inType = dyn_cast<VectorType>(in.getType());
98 if (!inType) {
99 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
100 loc, rewriter.getF32Type(), in, 0);
101 Value result = castF32To(outElemType, asFloat, loc, rewriter);
102 return rewriter.replaceOp(op, result);
104 int64_t numElements = inType.getNumElements();
105 Value zero = rewriter.create<arith::ConstantOp>(
106 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
107 if (inType.getShape().empty()) {
108 Value scalarIn =
109 rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
110 // Recurse to send the 0-D vector case to the 1-D vector case
111 Value scalarExt =
112 rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
113 Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
114 ArrayRef<int64_t>{});
115 return rewriter.replaceOp(op, result);
118 VectorType outType = cast<VectorType>(op.getOut().getType());
119 VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
120 outType.getElementType());
121 Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
123 if (inType.getRank() > 1) {
124 inType = VectorType::get(SmallVector<int64_t>{numElements},
125 inType.getElementType());
126 in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
129 for (int64_t i = 0; i < numElements; i += 4) {
130 int64_t elemsThisOp = std::min(numElements, i + 4) - i;
131 Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
132 loc, in, i, elemsThisOp, 1);
133 for (int64_t j = 0; j < elemsThisOp; ++j) {
134 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
135 loc, rewriter.getF32Type(), inSlice, j);
136 Value asType = castF32To(outElemType, asFloat, loc, rewriter);
137 result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
141 if (inType.getRank() != outType.getRank()) {
142 result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
145 rewriter.replaceOp(op, result);
148 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
149 Type type = value.getType();
150 if (type.isF32())
151 return value;
152 if (type.getIntOrFloatBitWidth() < 32)
153 return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
154 if (type.getIntOrFloatBitWidth() > 32)
155 return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
156 llvm_unreachable("The only 32-bit float type is f32");
159 // If `in` is a finite value, clamp it between the maximum and minimum values
160 // of `outElemType` so that subsequent conversion instructions don't
161 // overflow those out-of-range values to NaN. These semantics are commonly
162 // used in machine-learning contexts where failure to clamp would lead to
163 // excessive NaN production.
164 static Value clampInput(PatternRewriter &rewriter, Location loc,
165 Type outElemType, Value source) {
166 Type sourceType = source.getType();
167 const llvm::fltSemantics &sourceSem =
168 cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
169 const llvm::fltSemantics &targetSem =
170 cast<FloatType>(outElemType).getFloatSemantics();
172 APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
173 APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
174 bool ignoredLosesInfo = false;
175 // We can ignore conversion failures here because this conversion promotes
176 // from a smaller type to a larger one - ex. there can be no loss of precision
177 // when casting fp8 to f16.
178 (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
179 (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
181 Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
182 Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
184 Value inf = createScalarOrSplatConstant(
185 rewriter, loc, sourceType,
186 APFloat::getInf(sourceSem, /*Negative=*/false));
187 Value negInf = createScalarOrSplatConstant(
188 rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
189 Value isInf = rewriter.createOrFold<arith::CmpFOp>(
190 loc, arith::CmpFPredicate::OEQ, source, inf);
191 Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
192 loc, arith::CmpFPredicate::OEQ, source, negInf);
193 Value isNan = rewriter.createOrFold<arith::CmpFOp>(
194 loc, arith::CmpFPredicate::UNO, source, source);
195 Value isNonFinite = rewriter.create<arith::OrIOp>(
196 loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
198 Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
199 Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
200 Value res =
201 rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
202 return res;
205 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
206 // Only supporting default rounding mode as of now.
207 if (op.getRoundingmodeAttr())
208 return failure();
209 Type outType = op.getOut().getType();
210 if (auto outVecType = dyn_cast<VectorType>(outType)) {
211 if (outVecType.isScalable())
212 return failure();
213 outType = outVecType.getElementType();
215 auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
216 if (inType && inType.getWidth() <= 8 && saturateFP8)
217 // Conversion between 8-bit floats is not supported with truncation enabled.
218 return failure();
219 return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
222 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
223 PatternRewriter &rewriter) const {
224 Location loc = op.getLoc();
225 Value in = op.getIn();
226 Type outElemType = getElementTypeOrSelf(op.getOut().getType());
227 if (saturateFP8)
228 in = clampInput(rewriter, loc, outElemType, in);
229 auto inVectorTy = dyn_cast<VectorType>(in.getType());
230 VectorType truncResType = VectorType::get(4, outElemType);
231 if (!inVectorTy) {
232 Value asFloat = castToF32(in, loc, rewriter);
233 Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
234 loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
235 /*existing=*/nullptr);
236 Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
237 return rewriter.replaceOp(op, result);
239 VectorType outType = cast<VectorType>(op.getOut().getType());
240 int64_t numElements = outType.getNumElements();
241 Value zero = rewriter.create<arith::ConstantOp>(
242 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
243 if (outType.getShape().empty()) {
244 Value scalarIn =
245 rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
246 // Recurse to send the 0-D vector case to the 1-D vector case
247 Value scalarTrunc =
248 rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
249 Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
250 ArrayRef<int64_t>{});
251 return rewriter.replaceOp(op, result);
254 VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
255 outType.getElementType());
256 Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
258 if (inVectorTy.getRank() > 1) {
259 inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
260 inVectorTy.getElementType());
261 in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
264 for (int64_t i = 0; i < numElements; i += 4) {
265 int64_t elemsThisOp = std::min(numElements, i + 4) - i;
266 Value thisResult = nullptr;
267 for (int64_t j = 0; j < elemsThisOp; j += 2) {
268 Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
269 Value asFloatA = castToF32(elemA, loc, rewriter);
270 Value asFloatB = nullptr;
271 if (j + 1 < elemsThisOp) {
272 Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
273 asFloatB = castToF32(elemB, loc, rewriter);
275 thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
276 loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
278 if (elemsThisOp < 4)
279 thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
280 loc, thisResult, 0, elemsThisOp, 1);
281 result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
282 result, i, 1);
285 if (inVectorTy.getRank() != outType.getRank()) {
286 result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
289 rewriter.replaceOp(op, result);
292 LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
293 Type outType = op.getOut().getType();
294 Type inputType = getElementTypeOrSelf(op.getIn());
295 if (auto outVecType = dyn_cast<VectorType>(outType)) {
296 if (outVecType.isScalable())
297 return failure();
298 outType = outVecType.getElementType();
300 return success(outType.isF16() && inputType.isF32());
303 void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
304 PatternRewriter &rewriter) const {
305 Location loc = op.getLoc();
306 Value in = op.getIn();
307 Type outElemType = getElementTypeOrSelf(op.getOut().getType());
308 VectorType truncResType = VectorType::get(2, outElemType);
309 auto inVectorTy = dyn_cast<VectorType>(in.getType());
311 // Handle the case where input type is not a vector type
312 if (!inVectorTy) {
313 auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
314 Value asF16s =
315 rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
316 Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
317 return rewriter.replaceOp(op, result);
319 VectorType outType = cast<VectorType>(op.getOut().getType());
320 int64_t numElements = outType.getNumElements();
321 Value zero = rewriter.createOrFold<arith::ConstantOp>(
322 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
323 Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
325 if (inVectorTy.getRank() > 1) {
326 inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
327 inVectorTy.getElementType());
328 in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
331 // Handle the vector case. We also handle the (uncommon) case where the vector
332 // length is odd
333 for (int64_t i = 0; i < numElements; i += 2) {
334 int64_t elemsThisOp = std::min(numElements, i + 2) - i;
335 Value thisResult = nullptr;
336 Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
337 Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
339 if (elemsThisOp == 2) {
340 elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
343 thisResult =
344 rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
345 // Place back the truncated result into the possibly larger vector. If we
346 // are operating on a size 2 vector, these operations should be folded away
347 thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
348 loc, thisResult, 0, elemsThisOp, 1);
349 result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
350 result, i, 1);
353 if (inVectorTy.getRank() != outType.getRank()) {
354 result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
357 rewriter.replaceOp(op, result);
360 void mlir::arith::populateArithToAMDGPUConversionPatterns(
361 RewritePatternSet &patterns, bool convertFP8Arithmetic,
362 bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
364 if (convertFP8Arithmetic) {
365 patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
366 patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
367 saturateFP8Truncf, chipset);
369 if (allowPackedF16Rtz)
370 patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
373 void ArithToAMDGPUConversionPass::runOnOperation() {
374 Operation *op = getOperation();
375 MLIRContext *ctx = &getContext();
376 RewritePatternSet patterns(op->getContext());
377 FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
378 if (failed(maybeChipset)) {
379 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
380 return signalPassFailure();
383 bool convertFP8Arithmetic =
384 maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
385 arith::populateArithToAMDGPUConversionPatterns(
386 patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
387 *maybeChipset);
388 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
389 return signalPassFailure();