Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / GPUCommon / OpToFuncCallLowering.h
blob3b94abd88f9ed22f73da28c9864199c74352583c
1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
11 #include "mlir/Conversion/LLVMCommon/Pattern.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/Builders.h"
17 namespace mlir {
19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
20 /// `f32ApproxFunc` or `f16Func` depending on the element type and the
21 /// fastMathFlag of that Op. The function declaration is added in case it was
22 /// not added before.
23 ///
24 /// If the input values are of bf16 type (or f16 type if f16Func is empty), the
25 /// value is first casted to f32, the function called and then the result casted
26 /// back.
27 ///
28 /// Example with NVVM:
29 /// %exp_f32 = math.exp %arg_f32 : f32
30 ///
31 /// will be transformed into
32 /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
33 ///
34 /// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
35 /// to the approximate calculation function.
36 ///
37 /// Also example with NVVM:
38 /// %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
39 ///
40 /// will be transformed into
41 /// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
42 template <typename SourceOp>
43 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
44 public:
45 explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
46 StringRef f32Func, StringRef f64Func,
47 StringRef f32ApproxFunc, StringRef f16Func)
48 : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
49 f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
51 LogicalResult
52 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 using LLVM::LLVMFuncOp;
56 static_assert(
57 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
58 "expected single result op");
60 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
61 SourceOp>::value,
62 "expected op with same operand and result types");
64 if (!op->template getParentOfType<FunctionOpInterface>()) {
65 return rewriter.notifyMatchFailure(
66 op, "expected op to be within a function region");
69 SmallVector<Value, 1> castedOperands;
70 for (Value operand : adaptor.getOperands())
71 castedOperands.push_back(maybeCast(operand, rewriter));
73 Type resultType = castedOperands.front().getType();
74 Type funcType = getFunctionType(resultType, castedOperands);
75 StringRef funcName =
76 getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
77 op.getFastmath());
78 if (funcName.empty())
79 return failure();
81 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
82 auto callOp =
83 rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
85 if (resultType == adaptor.getOperands().front().getType()) {
86 rewriter.replaceOp(op, {callOp.getResult()});
87 return success();
90 Value truncated = rewriter.create<LLVM::FPTruncOp>(
91 op->getLoc(), adaptor.getOperands().front().getType(),
92 callOp.getResult());
93 rewriter.replaceOp(op, {truncated});
94 return success();
97 private:
98 Value maybeCast(Value operand, PatternRewriter &rewriter) const {
99 Type type = operand.getType();
100 if (!isa<Float16Type, BFloat16Type>(type))
101 return operand;
103 // if there's a f16 function, no need to cast f16 values
104 if (!f16Func.empty() && isa<Float16Type>(type))
105 return operand;
107 return rewriter.create<LLVM::FPExtOp>(
108 operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
111 Type getFunctionType(Type resultType, ValueRange operands) const {
112 SmallVector<Type> operandTypes(operands.getTypes());
113 return LLVM::LLVMFunctionType::get(resultType, operandTypes);
116 StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
117 if (isa<Float16Type>(type))
118 return f16Func;
119 if (isa<Float32Type>(type)) {
120 if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
121 !f32ApproxFunc.empty())
122 return f32ApproxFunc;
123 else
124 return f32Func;
126 if (isa<Float64Type>(type))
127 return f64Func;
128 return "";
131 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
132 Operation *op) const {
133 using LLVM::LLVMFuncOp;
135 auto funcAttr = StringAttr::get(op->getContext(), funcName);
136 Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
137 if (funcOp)
138 return cast<LLVMFuncOp>(*funcOp);
140 mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
141 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
144 const std::string f32Func;
145 const std::string f64Func;
146 const std::string f32ApproxFunc;
147 const std::string f16Func;
150 } // namespace mlir
152 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_