[Offload] Fix offload-info interface
[llvm-project.git] / flang / lib / Optimizer / CodeGen / FIROpPatterns.cpp
blob12021deb4bd97a0dbbe48c1698076f5d6e01edef
//===-- FIROpPatterns.cpp -- common FIR to LLVM pattern helpers -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11 //===----------------------------------------------------------------------===//
13 #include "flang/Optimizer/CodeGen/FIROpPatterns.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "llvm/Support/Debug.h"
17 static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context,
18 unsigned addressSpace = 0) {
19 return mlir::LLVM::LLVMPointerType::get(context, addressSpace);
22 static unsigned getTypeDescFieldId(mlir::Type ty) {
23 auto isArray = mlir::isa<fir::SequenceType>(fir::dyn_cast_ptrOrBoxEleTy(ty));
24 return isArray ? kOptTypePtrPosInBox : kDimsPosInBox;
27 namespace fir {
29 ConvertFIRToLLVMPattern::ConvertFIRToLLVMPattern(
30 llvm::StringRef rootOpName, mlir::MLIRContext *context,
31 const fir::LLVMTypeConverter &typeConverter,
32 const fir::FIRToLLVMPassOptions &options, mlir::PatternBenefit benefit)
33 : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
34 options(options) {}
36 // Convert FIR type to LLVM without turning fir.box<T> into memory
37 // reference.
38 mlir::Type
39 ConvertFIRToLLVMPattern::convertObjectType(mlir::Type firType) const {
40 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
41 return lowerTy().convertBoxTypeAsStruct(boxTy);
42 return lowerTy().convertType(firType);
45 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genI32Constant(
46 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
47 int value) const {
48 mlir::Type i32Ty = rewriter.getI32Type();
49 mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
50 return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
53 mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genConstantOffset(
54 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
55 int offset) const {
56 mlir::Type ity = lowerTy().offsetType();
57 mlir::IntegerAttr cattr = rewriter.getI32IntegerAttr(offset);
58 return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
61 /// Perform an extension or truncation as needed on an integer value. Lowering
62 /// to the specific target may involve some sign-extending or truncation of
63 /// values, particularly to fit them from abstract box types to the
64 /// appropriate reified structures.
65 mlir::Value ConvertFIRToLLVMPattern::integerCast(
66 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
67 mlir::Type ty, mlir::Value val, bool fold) const {
68 auto valTy = val.getType();
69 // If the value was not yet lowered, lower its type so that it can
70 // be used in getPrimitiveTypeSizeInBits.
71 if (!mlir::isa<mlir::IntegerType>(valTy))
72 valTy = convertType(valTy);
73 auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
74 auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
75 if (fold) {
76 if (toSize < fromSize)
77 return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
78 if (toSize > fromSize)
79 return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
80 } else {
81 if (toSize < fromSize)
82 return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
83 if (toSize > fromSize)
84 return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
86 return val;
89 fir::ConvertFIRToLLVMPattern::TypePair
90 ConvertFIRToLLVMPattern::getBoxTypePair(mlir::Type firBoxTy) const {
91 mlir::Type llvmBoxTy =
92 lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(firBoxTy));
93 return TypePair{firBoxTy, llvmBoxTy};
96 /// Construct code sequence to extract the specific value from a `fir.box`.
97 mlir::Value ConvertFIRToLLVMPattern::getValueFromBox(
98 mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Type resultTy,
99 mlir::ConversionPatternRewriter &rewriter, int boxValue) const {
100 if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) {
101 auto pty = getLlvmPtrType(resultTy.getContext());
102 auto p = rewriter.create<mlir::LLVM::GEPOp>(
103 loc, pty, boxTy.llvm, box,
104 llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue});
105 auto fldTy = getBoxEleTy(boxTy.llvm, {boxValue});
106 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, fldTy, p);
107 auto castOp = integerCast(loc, rewriter, resultTy, loadOp);
108 attachTBAATag(loadOp, boxTy.fir, nullptr, p);
109 return castOp;
111 return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, box, boxValue);
114 /// Method to construct code sequence to get the triple for dimension `dim`
115 /// from a box.
116 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox(
117 mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy,
118 mlir::Value box, mlir::Value dim,
119 mlir::ConversionPatternRewriter &rewriter) const {
120 mlir::Value l0 =
121 loadDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter);
122 mlir::Value l1 =
123 loadDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter);
124 mlir::Value l2 =
125 loadDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter);
126 return {l0, l1, l2};
129 llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox(
130 mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy,
131 mlir::Value box, int dim, mlir::ConversionPatternRewriter &rewriter) const {
132 mlir::Value l0 =
133 getDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter);
134 mlir::Value l1 =
135 getDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter);
136 mlir::Value l2 =
137 getDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter);
138 return {l0, l1, l2};
141 mlir::Value ConvertFIRToLLVMPattern::loadDimFieldFromBox(
142 mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Value dim,
143 int off, mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const {
144 assert(mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType()) &&
145 "descriptor inquiry with runtime dim can only be done on descriptor "
146 "in memory");
147 mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0,
148 static_cast<int>(kDimsPosInBox), dim, off);
149 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
150 attachTBAATag(loadOp, boxTy.fir, nullptr, p);
151 return loadOp;
154 mlir::Value ConvertFIRToLLVMPattern::getDimFieldFromBox(
155 mlir::Location loc, TypePair boxTy, mlir::Value box, int dim, int off,
156 mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const {
157 if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) {
158 mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0,
159 static_cast<int>(kDimsPosInBox), dim, off);
160 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
161 attachTBAATag(loadOp, boxTy.fir, nullptr, p);
162 return loadOp;
164 return rewriter.create<mlir::LLVM::ExtractValueOp>(
165 loc, box, llvm::ArrayRef<std::int64_t>{kDimsPosInBox, dim, off});
168 mlir::Value ConvertFIRToLLVMPattern::getStrideFromBox(
169 mlir::Location loc, TypePair boxTy, mlir::Value box, unsigned dim,
170 mlir::ConversionPatternRewriter &rewriter) const {
171 auto idxTy = lowerTy().indexType();
172 return getDimFieldFromBox(loc, boxTy, box, dim, kDimStridePos, idxTy,
173 rewriter);
176 /// Read base address from a fir.box. Returned address has type ty.
177 mlir::Value ConvertFIRToLLVMPattern::getBaseAddrFromBox(
178 mlir::Location loc, TypePair boxTy, mlir::Value box,
179 mlir::ConversionPatternRewriter &rewriter) const {
180 mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext());
181 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox);
184 mlir::Value ConvertFIRToLLVMPattern::getElementSizeFromBox(
185 mlir::Location loc, mlir::Type resultTy, TypePair boxTy, mlir::Value box,
186 mlir::ConversionPatternRewriter &rewriter) const {
187 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kElemLenPosInBox);
190 /// Read base address from a fir.box. Returned address has type ty.
191 mlir::Value ConvertFIRToLLVMPattern::getRankFromBox(
192 mlir::Location loc, TypePair boxTy, mlir::Value box,
193 mlir::ConversionPatternRewriter &rewriter) const {
194 mlir::Type resultTy = getBoxEleTy(boxTy.llvm, {kRankPosInBox});
195 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kRankPosInBox);
198 /// Read the extra field from a fir.box.
199 mlir::Value ConvertFIRToLLVMPattern::getExtraFromBox(
200 mlir::Location loc, TypePair boxTy, mlir::Value box,
201 mlir::ConversionPatternRewriter &rewriter) const {
202 mlir::Type resultTy = getBoxEleTy(boxTy.llvm, {kExtraPosInBox});
203 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kExtraPosInBox);
206 // Get the element type given an LLVM type that is of the form
207 // (array|struct|vector)+ and the provided indexes.
208 mlir::Type ConvertFIRToLLVMPattern::getBoxEleTy(
209 mlir::Type type, llvm::ArrayRef<std::int64_t> indexes) const {
210 for (unsigned i : indexes) {
211 if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(type)) {
212 assert(!t.isOpaque() && i < t.getBody().size());
213 type = t.getBody()[i];
214 } else if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) {
215 type = t.getElementType();
216 } else if (auto t = mlir::dyn_cast<mlir::VectorType>(type)) {
217 type = t.getElementType();
218 } else {
219 fir::emitFatalError(mlir::UnknownLoc::get(type.getContext()),
220 "request for invalid box element type");
223 return type;
226 // Return LLVM type of the object described by a fir.box of \p boxType.
227 mlir::Type ConvertFIRToLLVMPattern::getLlvmObjectTypeFromBoxType(
228 mlir::Type boxType) const {
229 mlir::Type objectType = fir::dyn_cast_ptrOrBoxEleTy(boxType);
230 assert(objectType && "boxType must be a box type");
231 return this->convertType(objectType);
234 /// Read the address of the type descriptor from a box.
235 mlir::Value ConvertFIRToLLVMPattern::loadTypeDescAddress(
236 mlir::Location loc, TypePair boxTy, mlir::Value box,
237 mlir::ConversionPatternRewriter &rewriter) const {
238 unsigned typeDescFieldId = getTypeDescFieldId(boxTy.fir);
239 mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext());
240 return getValueFromBox(loc, boxTy, box, tdescType, rewriter, typeDescFieldId);
243 // Load the attribute from the \p box and perform a check against \p maskValue
244 // The final comparison is implemented as `(attribute & maskValue) != 0`.
245 mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck(
246 mlir::Location loc, TypePair boxTy, mlir::Value box,
247 mlir::ConversionPatternRewriter &rewriter, unsigned maskValue) const {
248 mlir::Type attrTy = rewriter.getI32Type();
249 mlir::Value attribute =
250 getValueFromBox(loc, boxTy, box, attrTy, rewriter, kAttributePosInBox);
251 mlir::LLVM::ConstantOp attrMask = genConstantOffset(loc, rewriter, maskValue);
252 auto maskRes =
253 rewriter.create<mlir::LLVM::AndOp>(loc, attrTy, attribute, attrMask);
254 mlir::LLVM::ConstantOp c0 = genConstantOffset(loc, rewriter, 0);
255 return rewriter.create<mlir::LLVM::ICmpOp>(loc, mlir::LLVM::ICmpPredicate::ne,
256 maskRes, c0);
259 mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
260 mlir::Location loc, TypePair boxTy, mlir::Value box,
261 mlir::ConversionPatternRewriter &rewriter) const {
262 auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir);
263 assert(firBoxType && "must be a BaseBoxType");
264 const mlir::DataLayout &dl = lowerTy().getDataLayout();
265 if (!firBoxType.isAssumedRank())
266 return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm));
267 fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
268 mlir::Type llvmScalarBoxType =
269 lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
270 llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType);
271 mlir::Value scalarBoxSize =
272 genConstantOffset(loc, rewriter, scalarBoxSizeCst);
273 mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter);
274 mlir::Value rank =
275 integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
276 mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
277 llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType);
278 assert((scalarBoxSizeCst + sizePerDimCst ==
279 dl.getTypeSize(lowerTy().convertBoxTypeAsStruct(
280 firBoxType.getBoxTypeWithNewShape(1)))) &&
281 "descriptor layout requires adding padding for dim field");
282 mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst);
283 mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>(
284 loc, sizePerDim.getType(), sizePerDim, rank);
285 mlir::Value size = rewriter.create<mlir::LLVM::AddOp>(
286 loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize);
287 return size;
290 // Find the Block in which the alloca should be inserted.
291 // The order to recursively find the proper block:
292 // 1. An OpenMP Op that will be outlined.
293 // 2. An OpenMP or OpenACC Op with one or more regions holding executable code.
294 // 3. A LLVMFuncOp
295 // 4. The first ancestor that is one of the above.
296 mlir::Block *ConvertFIRToLLVMPattern::getBlockForAllocaInsert(
297 mlir::Operation *op, mlir::Region *parentRegion) const {
298 if (auto iface = mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(op))
299 return iface.getAllocaBlock();
300 if (auto recipeIface = mlir::dyn_cast<mlir::accomp::RecipeInterface>(op))
301 return recipeIface.getAllocaBlock(*parentRegion);
302 if (auto llvmFuncOp = mlir::dyn_cast<mlir::LLVM::LLVMFuncOp>(op))
303 return &llvmFuncOp.front();
305 return getBlockForAllocaInsert(op->getParentOp(), parentRegion);
308 // Generate an alloca of size 1 for an object of type \p llvmObjectTy in the
309 // allocation address space provided for the architecture in the DataLayout
310 // specification. If the address space is different from the devices
311 // program address space we perform a cast. In the case of most architectures
312 // the program and allocation address space will be the default of 0 and no
313 // cast will be emitted.
314 mlir::Value ConvertFIRToLLVMPattern::genAllocaAndAddrCastWithType(
315 mlir::Location loc, mlir::Type llvmObjectTy, unsigned alignment,
316 mlir::ConversionPatternRewriter &rewriter) const {
317 auto thisPt = rewriter.saveInsertionPoint();
318 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
319 mlir::Region *parentRegion = rewriter.getInsertionBlock()->getParent();
320 mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp, parentRegion);
321 rewriter.setInsertionPointToStart(insertBlock);
322 auto size = genI32Constant(loc, rewriter, 1);
323 unsigned allocaAs = getAllocaAddressSpace(rewriter);
324 unsigned programAs = getProgramAddressSpace(rewriter);
326 mlir::Value al = rewriter.create<mlir::LLVM::AllocaOp>(
327 loc, ::getLlvmPtrType(llvmObjectTy.getContext(), allocaAs), llvmObjectTy,
328 size, alignment);
330 // if our allocation address space, is not the same as the program address
331 // space, then we must emit a cast to the program address space before use.
332 // An example case would be on AMDGPU, where the allocation address space is
333 // the numeric value 5 (private), and the program address space is 0
334 // (generic).
335 if (allocaAs != programAs) {
336 al = rewriter.create<mlir::LLVM::AddrSpaceCastOp>(
337 loc, ::getLlvmPtrType(llvmObjectTy.getContext(), programAs), al);
340 rewriter.restoreInsertionPoint(thisPt);
341 return al;
344 unsigned ConvertFIRToLLVMPattern::getAllocaAddressSpace(
345 mlir::ConversionPatternRewriter &rewriter) const {
346 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
347 assert(parentOp != nullptr &&
348 "expected insertion block to have parent operation");
349 if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
350 if (mlir::Attribute addrSpace =
351 mlir::DataLayout(module).getAllocaMemorySpace())
352 return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
353 return defaultAddressSpace;
356 unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace(
357 mlir::ConversionPatternRewriter &rewriter) const {
358 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
359 assert(parentOp != nullptr &&
360 "expected insertion block to have parent operation");
361 if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
362 if (mlir::Attribute addrSpace =
363 mlir::DataLayout(module).getProgramMemorySpace())
364 return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
365 return defaultAddressSpace;
368 } // namespace fir