1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
11 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Arith/Utils/Utils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
26 #include "mlir/Conversion/Passes.h.inc"
30 using namespace mlir::amdgpu
;
33 struct ArithToAMDGPUConversionPass final
34 : impl::ArithToAMDGPUConversionPassBase
<ArithToAMDGPUConversionPass
> {
35 using impl::ArithToAMDGPUConversionPassBase
<
36 ArithToAMDGPUConversionPass
>::ArithToAMDGPUConversionPassBase
;
38 void runOnOperation() override
;
41 struct ExtFOnFloat8RewritePattern final
: OpRewritePattern
<arith::ExtFOp
> {
42 using OpRewritePattern::OpRewritePattern
;
44 LogicalResult
match(arith::ExtFOp op
) const override
;
45 void rewrite(arith::ExtFOp op
, PatternRewriter
&rewriter
) const override
;
48 struct TruncFToFloat8RewritePattern final
: OpRewritePattern
<arith::TruncFOp
> {
49 bool saturateFP8
= false;
50 TruncFToFloat8RewritePattern(MLIRContext
*ctx
, bool saturateFP8
,
52 : OpRewritePattern::OpRewritePattern(ctx
), saturateFP8(saturateFP8
),
56 LogicalResult
match(arith::TruncFOp op
) const override
;
57 void rewrite(arith::TruncFOp op
, PatternRewriter
&rewriter
) const override
;
60 struct TruncfToFloat16RewritePattern final
61 : public OpRewritePattern
<arith::TruncFOp
> {
63 using OpRewritePattern
<arith::TruncFOp
>::OpRewritePattern
;
65 LogicalResult
match(arith::TruncFOp op
) const override
;
66 void rewrite(arith::TruncFOp op
, PatternRewriter
&rewriter
) const override
;
71 static Value
castF32To(Type elementType
, Value f32
, Location loc
,
72 PatternRewriter
&rewriter
) {
73 if (elementType
.isF32())
75 if (elementType
.getIntOrFloatBitWidth() < 32)
76 return rewriter
.create
<arith::TruncFOp
>(loc
, elementType
, f32
);
77 if (elementType
.getIntOrFloatBitWidth() > 32)
78 return rewriter
.create
<arith::ExtFOp
>(loc
, elementType
, f32
);
79 llvm_unreachable("The only 32-bit float type is f32");
82 LogicalResult
ExtFOnFloat8RewritePattern::match(arith::ExtFOp op
) const {
83 Type inType
= op
.getIn().getType();
84 if (auto inVecType
= dyn_cast
<VectorType
>(inType
)) {
85 if (inVecType
.isScalable())
87 inType
= inVecType
.getElementType();
89 return success(inType
.isFloat8E5M2FNUZ() || inType
.isFloat8E4M3FNUZ());
92 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op
,
93 PatternRewriter
&rewriter
) const {
94 Location loc
= op
.getLoc();
95 Value in
= op
.getIn();
96 Type outElemType
= getElementTypeOrSelf(op
.getOut().getType());
97 auto inType
= dyn_cast
<VectorType
>(in
.getType());
99 Value asFloat
= rewriter
.create
<amdgpu::ExtPackedFp8Op
>(
100 loc
, rewriter
.getF32Type(), in
, 0);
101 Value result
= castF32To(outElemType
, asFloat
, loc
, rewriter
);
102 return rewriter
.replaceOp(op
, result
);
104 int64_t numElements
= inType
.getNumElements();
105 Value zero
= rewriter
.create
<arith::ConstantOp
>(
106 loc
, outElemType
, rewriter
.getFloatAttr(outElemType
, 0.0));
107 if (inType
.getShape().empty()) {
109 rewriter
.create
<vector::ExtractOp
>(loc
, in
, ArrayRef
<int64_t>{});
110 // Recurse to send the 0-D vector case to the 1-D vector case
112 rewriter
.create
<arith::ExtFOp
>(loc
, outElemType
, scalarIn
);
113 Value result
= rewriter
.create
<vector::InsertOp
>(loc
, scalarExt
, zero
,
114 ArrayRef
<int64_t>{});
115 return rewriter
.replaceOp(op
, result
);
118 VectorType outType
= cast
<VectorType
>(op
.getOut().getType());
119 VectorType flatTy
= VectorType::get(SmallVector
<int64_t>{numElements
},
120 outType
.getElementType());
121 Value result
= rewriter
.createOrFold
<vector::SplatOp
>(loc
, flatTy
, zero
);
123 if (inType
.getRank() > 1) {
124 inType
= VectorType::get(SmallVector
<int64_t>{numElements
},
125 inType
.getElementType());
126 in
= rewriter
.create
<vector::ShapeCastOp
>(loc
, inType
, in
);
129 for (int64_t i
= 0; i
< numElements
; i
+= 4) {
130 int64_t elemsThisOp
= std::min(numElements
, i
+ 4) - i
;
131 Value inSlice
= rewriter
.create
<vector::ExtractStridedSliceOp
>(
132 loc
, in
, i
, elemsThisOp
, 1);
133 for (int64_t j
= 0; j
< elemsThisOp
; ++j
) {
134 Value asFloat
= rewriter
.create
<amdgpu::ExtPackedFp8Op
>(
135 loc
, rewriter
.getF32Type(), inSlice
, j
);
136 Value asType
= castF32To(outElemType
, asFloat
, loc
, rewriter
);
137 result
= rewriter
.create
<vector::InsertOp
>(loc
, asType
, result
, i
+ j
);
141 if (inType
.getRank() != outType
.getRank()) {
142 result
= rewriter
.create
<vector::ShapeCastOp
>(loc
, outType
, result
);
145 rewriter
.replaceOp(op
, result
);
148 static Value
castToF32(Value value
, Location loc
, PatternRewriter
&rewriter
) {
149 Type type
= value
.getType();
152 if (type
.getIntOrFloatBitWidth() < 32)
153 return rewriter
.create
<arith::ExtFOp
>(loc
, rewriter
.getF32Type(), value
);
154 if (type
.getIntOrFloatBitWidth() > 32)
155 return rewriter
.create
<arith::TruncFOp
>(loc
, rewriter
.getF32Type(), value
);
156 llvm_unreachable("The only 32-bit float type is f32");
159 // If `in` is a finite value, clamp it between the maximum and minimum values
160 // of `outElemType` so that subsequent conversion instructions don't
161 // overflow those out-of-range values to NaN. These semantics are commonly
162 // used in machine-learning contexts where failure to clamp would lead to
163 // excessive NaN production.
164 static Value
clampInput(PatternRewriter
&rewriter
, Location loc
,
165 Type outElemType
, Value source
) {
166 Type sourceType
= source
.getType();
167 const llvm::fltSemantics
&sourceSem
=
168 cast
<FloatType
>(getElementTypeOrSelf(sourceType
)).getFloatSemantics();
169 const llvm::fltSemantics
&targetSem
=
170 cast
<FloatType
>(outElemType
).getFloatSemantics();
172 APFloat min
= APFloat::getLargest(targetSem
, /*Negative=*/true);
173 APFloat max
= APFloat::getLargest(targetSem
, /*Negative=*/false);
174 bool ignoredLosesInfo
= false;
175 // We can ignore conversion failures here because this conversion promotes
176 // from a smaller type to a larger one - ex. there can be no loss of precision
177 // when casting fp8 to f16.
178 (void)min
.convert(sourceSem
, APFloat::rmNearestTiesToEven
, &ignoredLosesInfo
);
179 (void)max
.convert(sourceSem
, APFloat::rmNearestTiesToEven
, &ignoredLosesInfo
);
181 Value minCst
= createScalarOrSplatConstant(rewriter
, loc
, sourceType
, min
);
182 Value maxCst
= createScalarOrSplatConstant(rewriter
, loc
, sourceType
, max
);
184 Value inf
= createScalarOrSplatConstant(
185 rewriter
, loc
, sourceType
,
186 APFloat::getInf(sourceSem
, /*Negative=*/false));
187 Value negInf
= createScalarOrSplatConstant(
188 rewriter
, loc
, sourceType
, APFloat::getInf(sourceSem
, /*Negative=*/true));
189 Value isInf
= rewriter
.createOrFold
<arith::CmpFOp
>(
190 loc
, arith::CmpFPredicate::OEQ
, source
, inf
);
191 Value isNegInf
= rewriter
.createOrFold
<arith::CmpFOp
>(
192 loc
, arith::CmpFPredicate::OEQ
, source
, negInf
);
193 Value isNan
= rewriter
.createOrFold
<arith::CmpFOp
>(
194 loc
, arith::CmpFPredicate::UNO
, source
, source
);
195 Value isNonFinite
= rewriter
.create
<arith::OrIOp
>(
196 loc
, rewriter
.create
<arith::OrIOp
>(loc
, isInf
, isNegInf
), isNan
);
198 Value clampedBelow
= rewriter
.create
<arith::MaximumFOp
>(loc
, source
, minCst
);
199 Value clamped
= rewriter
.create
<arith::MinimumFOp
>(loc
, clampedBelow
, maxCst
);
201 rewriter
.create
<arith::SelectOp
>(loc
, isNonFinite
, source
, clamped
);
205 LogicalResult
TruncFToFloat8RewritePattern::match(arith::TruncFOp op
) const {
206 // Only supporting default rounding mode as of now.
207 if (op
.getRoundingmodeAttr())
209 Type outType
= op
.getOut().getType();
210 if (auto outVecType
= dyn_cast
<VectorType
>(outType
)) {
211 if (outVecType
.isScalable())
213 outType
= outVecType
.getElementType();
215 auto inType
= dyn_cast
<FloatType
>(getElementTypeOrSelf(op
.getIn().getType()));
216 if (inType
&& inType
.getWidth() <= 8 && saturateFP8
)
217 // Conversion between 8-bit floats is not supported with truncation enabled.
219 return success(outType
.isFloat8E5M2FNUZ() || outType
.isFloat8E4M3FNUZ());
222 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op
,
223 PatternRewriter
&rewriter
) const {
224 Location loc
= op
.getLoc();
225 Value in
= op
.getIn();
226 Type outElemType
= getElementTypeOrSelf(op
.getOut().getType());
228 in
= clampInput(rewriter
, loc
, outElemType
, in
);
229 auto inVectorTy
= dyn_cast
<VectorType
>(in
.getType());
230 VectorType truncResType
= VectorType::get(4, outElemType
);
232 Value asFloat
= castToF32(in
, loc
, rewriter
);
233 Value asF8s
= rewriter
.create
<amdgpu::PackedTrunc2xFp8Op
>(
234 loc
, truncResType
, asFloat
, /*sourceB=*/nullptr, 0,
235 /*existing=*/nullptr);
236 Value result
= rewriter
.create
<vector::ExtractOp
>(loc
, asF8s
, 0);
237 return rewriter
.replaceOp(op
, result
);
239 VectorType outType
= cast
<VectorType
>(op
.getOut().getType());
240 int64_t numElements
= outType
.getNumElements();
241 Value zero
= rewriter
.create
<arith::ConstantOp
>(
242 loc
, outElemType
, rewriter
.getFloatAttr(outElemType
, 0.0));
243 if (outType
.getShape().empty()) {
245 rewriter
.create
<vector::ExtractOp
>(loc
, in
, ArrayRef
<int64_t>{});
246 // Recurse to send the 0-D vector case to the 1-D vector case
248 rewriter
.create
<arith::TruncFOp
>(loc
, outElemType
, scalarIn
);
249 Value result
= rewriter
.create
<vector::InsertOp
>(loc
, scalarTrunc
, zero
,
250 ArrayRef
<int64_t>{});
251 return rewriter
.replaceOp(op
, result
);
254 VectorType flatTy
= VectorType::get(SmallVector
<int64_t>{numElements
},
255 outType
.getElementType());
256 Value result
= rewriter
.createOrFold
<vector::SplatOp
>(loc
, flatTy
, zero
);
258 if (inVectorTy
.getRank() > 1) {
259 inVectorTy
= VectorType::get(SmallVector
<int64_t>{numElements
},
260 inVectorTy
.getElementType());
261 in
= rewriter
.create
<vector::ShapeCastOp
>(loc
, inVectorTy
, in
);
264 for (int64_t i
= 0; i
< numElements
; i
+= 4) {
265 int64_t elemsThisOp
= std::min(numElements
, i
+ 4) - i
;
266 Value thisResult
= nullptr;
267 for (int64_t j
= 0; j
< elemsThisOp
; j
+= 2) {
268 Value elemA
= rewriter
.create
<vector::ExtractOp
>(loc
, in
, i
+ j
);
269 Value asFloatA
= castToF32(elemA
, loc
, rewriter
);
270 Value asFloatB
= nullptr;
271 if (j
+ 1 < elemsThisOp
) {
272 Value elemB
= rewriter
.create
<vector::ExtractOp
>(loc
, in
, i
+ j
+ 1);
273 asFloatB
= castToF32(elemB
, loc
, rewriter
);
275 thisResult
= rewriter
.create
<amdgpu::PackedTrunc2xFp8Op
>(
276 loc
, truncResType
, asFloatA
, asFloatB
, j
/ 2, thisResult
);
279 thisResult
= rewriter
.create
<vector::ExtractStridedSliceOp
>(
280 loc
, thisResult
, 0, elemsThisOp
, 1);
281 result
= rewriter
.create
<vector::InsertStridedSliceOp
>(loc
, thisResult
,
285 if (inVectorTy
.getRank() != outType
.getRank()) {
286 result
= rewriter
.create
<vector::ShapeCastOp
>(loc
, outType
, result
);
289 rewriter
.replaceOp(op
, result
);
292 LogicalResult
TruncfToFloat16RewritePattern::match(arith::TruncFOp op
) const {
293 Type outType
= op
.getOut().getType();
294 Type inputType
= getElementTypeOrSelf(op
.getIn());
295 if (auto outVecType
= dyn_cast
<VectorType
>(outType
)) {
296 if (outVecType
.isScalable())
298 outType
= outVecType
.getElementType();
300 return success(outType
.isF16() && inputType
.isF32());
303 void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op
,
304 PatternRewriter
&rewriter
) const {
305 Location loc
= op
.getLoc();
306 Value in
= op
.getIn();
307 Type outElemType
= getElementTypeOrSelf(op
.getOut().getType());
308 VectorType truncResType
= VectorType::get(2, outElemType
);
309 auto inVectorTy
= dyn_cast
<VectorType
>(in
.getType());
311 // Handle the case where input type is not a vector type
313 auto sourceB
= rewriter
.create
<LLVM::PoisonOp
>(loc
, rewriter
.getF32Type());
315 rewriter
.create
<ROCDL::CvtPkRtz
>(loc
, truncResType
, in
, sourceB
);
316 Value result
= rewriter
.create
<vector::ExtractOp
>(loc
, asF16s
, 0);
317 return rewriter
.replaceOp(op
, result
);
319 VectorType outType
= cast
<VectorType
>(op
.getOut().getType());
320 int64_t numElements
= outType
.getNumElements();
321 Value zero
= rewriter
.createOrFold
<arith::ConstantOp
>(
322 loc
, outElemType
, rewriter
.getFloatAttr(outElemType
, 0.0));
323 Value result
= rewriter
.createOrFold
<vector::SplatOp
>(loc
, outType
, zero
);
325 if (inVectorTy
.getRank() > 1) {
326 inVectorTy
= VectorType::get(SmallVector
<int64_t>{numElements
},
327 inVectorTy
.getElementType());
328 in
= rewriter
.create
<vector::ShapeCastOp
>(loc
, inVectorTy
, in
);
331 // Handle the vector case. We also handle the (uncommon) case where the vector
333 for (int64_t i
= 0; i
< numElements
; i
+= 2) {
334 int64_t elemsThisOp
= std::min(numElements
, i
+ 2) - i
;
335 Value thisResult
= nullptr;
336 Value elemA
= rewriter
.create
<vector::ExtractOp
>(loc
, in
, i
);
337 Value elemB
= rewriter
.create
<LLVM::PoisonOp
>(loc
, rewriter
.getF32Type());
339 if (elemsThisOp
== 2) {
340 elemB
= rewriter
.create
<vector::ExtractOp
>(loc
, in
, i
+ 1);
344 rewriter
.create
<ROCDL::CvtPkRtz
>(loc
, truncResType
, elemA
, elemB
);
345 // Place back the truncated result into the possibly larger vector. If we
346 // are operating on a size 2 vector, these operations should be folded away
347 thisResult
= rewriter
.create
<vector::ExtractStridedSliceOp
>(
348 loc
, thisResult
, 0, elemsThisOp
, 1);
349 result
= rewriter
.create
<vector::InsertStridedSliceOp
>(loc
, thisResult
,
353 if (inVectorTy
.getRank() != outType
.getRank()) {
354 result
= rewriter
.create
<vector::ShapeCastOp
>(loc
, outType
, result
);
357 rewriter
.replaceOp(op
, result
);
360 void mlir::arith::populateArithToAMDGPUConversionPatterns(
361 RewritePatternSet
&patterns
, bool convertFP8Arithmetic
,
362 bool saturateFP8Truncf
, bool allowPackedF16Rtz
, Chipset chipset
) {
364 if (convertFP8Arithmetic
) {
365 patterns
.add
<ExtFOnFloat8RewritePattern
>(patterns
.getContext());
366 patterns
.add
<TruncFToFloat8RewritePattern
>(patterns
.getContext(),
367 saturateFP8Truncf
, chipset
);
369 if (allowPackedF16Rtz
)
370 patterns
.add
<TruncfToFloat16RewritePattern
>(patterns
.getContext());
373 void ArithToAMDGPUConversionPass::runOnOperation() {
374 Operation
*op
= getOperation();
375 MLIRContext
*ctx
= &getContext();
376 RewritePatternSet
patterns(op
->getContext());
377 FailureOr
<amdgpu::Chipset
> maybeChipset
= amdgpu::Chipset::parse(chipset
);
378 if (failed(maybeChipset
)) {
379 emitError(UnknownLoc::get(ctx
), "Invalid chipset name: " + chipset
);
380 return signalPassFailure();
383 bool convertFP8Arithmetic
=
384 maybeChipset
->majorVersion
== 9 && *maybeChipset
>= Chipset(9, 4, 0);
385 arith::populateArithToAMDGPUConversionPatterns(
386 patterns
, convertFP8Arithmetic
, saturateFP8Truncf
, allowPackedF16Rtz
,
388 if (failed(applyPatternsAndFoldGreedily(op
, std::move(patterns
))))
389 return signalPassFailure();