1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
11 #include "mlir/Conversion/LLVMCommon/Pattern.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/Builders.h"
19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20 /// `f32ApproxFunc` or `f16Func` depending on the element type and the
21 /// fastMathFlag of that Op. The function declaration is added in case it was
24 /// If the input values are of bf16 type (or f16 type if f16Func is empty), the
25 /// value is first casted to f32, the function called and then the result casted
28 /// Example with NVVM:
29 /// %exp_f32 = math.exp %arg_f32 : f32
31 /// will be transformed into
32 /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
34 /// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
35 /// to the approximate calculation function.
37 /// Also example with NVVM:
38 /// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
40 /// will be transformed into
41 /// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
42 template <typename SourceOp
>
43 struct OpToFuncCallLowering
: public ConvertOpToLLVMPattern
<SourceOp
> {
45 explicit OpToFuncCallLowering(const LLVMTypeConverter
&lowering
,
46 StringRef f32Func
, StringRef f64Func
,
47 StringRef f32ApproxFunc
, StringRef f16Func
)
48 : ConvertOpToLLVMPattern
<SourceOp
>(lowering
), f32Func(f32Func
),
49 f64Func(f64Func
), f32ApproxFunc(f32ApproxFunc
), f16Func(f16Func
) {}
52 matchAndRewrite(SourceOp op
, typename
SourceOp::Adaptor adaptor
,
53 ConversionPatternRewriter
&rewriter
) const override
{
54 using LLVM::LLVMFuncOp
;
57 std::is_base_of
<OpTrait::OneResult
<SourceOp
>, SourceOp
>::value
,
58 "expected single result op");
60 static_assert(std::is_base_of
<OpTrait::SameOperandsAndResultType
<SourceOp
>,
62 "expected op with same operand and result types");
64 if (!op
->template getParentOfType
<FunctionOpInterface
>()) {
65 return rewriter
.notifyMatchFailure(
66 op
, "expected op to be within a function region");
69 SmallVector
<Value
, 1> castedOperands
;
70 for (Value operand
: adaptor
.getOperands())
71 castedOperands
.push_back(maybeCast(operand
, rewriter
));
73 Type resultType
= castedOperands
.front().getType();
74 Type funcType
= getFunctionType(resultType
, castedOperands
);
76 getFunctionName(cast
<LLVM::LLVMFunctionType
>(funcType
).getReturnType(),
81 LLVMFuncOp funcOp
= appendOrGetFuncOp(funcName
, funcType
, op
);
83 rewriter
.create
<LLVM::CallOp
>(op
->getLoc(), funcOp
, castedOperands
);
85 if (resultType
== adaptor
.getOperands().front().getType()) {
86 rewriter
.replaceOp(op
, {callOp
.getResult()});
90 Value truncated
= rewriter
.create
<LLVM::FPTruncOp
>(
91 op
->getLoc(), adaptor
.getOperands().front().getType(),
93 rewriter
.replaceOp(op
, {truncated
});
98 Value
maybeCast(Value operand
, PatternRewriter
&rewriter
) const {
99 Type type
= operand
.getType();
100 if (!isa
<Float16Type
, BFloat16Type
>(type
))
103 // if there's a f16 function, no need to cast f16 values
104 if (!f16Func
.empty() && isa
<Float16Type
>(type
))
107 return rewriter
.create
<LLVM::FPExtOp
>(
108 operand
.getLoc(), Float32Type::get(rewriter
.getContext()), operand
);
111 Type
getFunctionType(Type resultType
, ValueRange operands
) const {
112 SmallVector
<Type
> operandTypes(operands
.getTypes());
113 return LLVM::LLVMFunctionType::get(resultType
, operandTypes
);
116 StringRef
getFunctionName(Type type
, arith::FastMathFlags flag
) const {
117 if (isa
<Float16Type
>(type
))
119 if (isa
<Float32Type
>(type
)) {
120 if (((uint32_t)arith::FastMathFlags::afn
& (uint32_t)flag
) &&
121 !f32ApproxFunc
.empty())
122 return f32ApproxFunc
;
126 if (isa
<Float64Type
>(type
))
131 LLVM::LLVMFuncOp
appendOrGetFuncOp(StringRef funcName
, Type funcType
,
132 Operation
*op
) const {
133 using LLVM::LLVMFuncOp
;
135 auto funcAttr
= StringAttr::get(op
->getContext(), funcName
);
136 Operation
*funcOp
= SymbolTable::lookupNearestSymbolFrom(op
, funcAttr
);
138 return cast
<LLVMFuncOp
>(*funcOp
);
140 mlir::OpBuilder
b(op
->getParentOfType
<FunctionOpInterface
>());
141 return b
.create
<LLVMFuncOp
>(op
->getLoc(), funcName
, funcType
);
144 const std::string f32Func
;
145 const std::string f64Func
;
146 const std::string f32ApproxFunc
;
147 const std::string f16Func
;
152 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_