[Workflow] Roll back some settings since they caused more issues
[llvm-project.git] / mlir / lib / Target / LLVMIR / ModuleTranslation.cpp
blob33c85d85a684cc854f249d4d2f7e0dfccc6eca6a
1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
16 #include "AttrKindDetail.h"
17 #include "DebugTranslation.h"
18 #include "LoopAnnotationTranslation.h"
19 #include "mlir/Dialect/DLTI/DLTI.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
22 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
24 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
25 #include "mlir/IR/AttrTypeSubElements.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/RegionGraphTraits.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Support/LogicalResult.h"
32 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
33 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
35 #include "llvm/ADT/PostOrderIterator.h"
36 #include "llvm/ADT/SetVector.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/CFG.h"
41 #include "llvm/IR/Constants.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/InlineAsm.h"
45 #include "llvm/IR/IntrinsicsNVPTX.h"
46 #include "llvm/IR/LLVMContext.h"
47 #include "llvm/IR/MDBuilder.h"
48 #include "llvm/IR/Module.h"
49 #include "llvm/IR/Verifier.h"
50 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
51 #include "llvm/Transforms/Utils/Cloning.h"
52 #include "llvm/Transforms/Utils/ModuleUtils.h"
53 #include <optional>
55 using namespace mlir;
56 using namespace mlir::LLVM;
57 using namespace mlir::LLVM::detail;
59 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
61 /// Translates the given data layout spec attribute to the LLVM IR data layout.
62 /// Only integer, float, pointer and endianness entries are currently supported.
63 static FailureOr<llvm::DataLayout>
64 translateDataLayout(DataLayoutSpecInterface attribute,
65 const DataLayout &dataLayout,
66 std::optional<Location> loc = std::nullopt) {
67 if (!loc)
68 loc = UnknownLoc::get(attribute.getContext());
70 // Translate the endianness attribute.
71 std::string llvmDataLayout;
72 llvm::raw_string_ostream layoutStream(llvmDataLayout);
73 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
74 auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
75 if (!key)
76 continue;
77 if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
78 auto value = cast<StringAttr>(entry.getValue());
79 bool isLittleEndian =
80 value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
81 layoutStream << "-" << (isLittleEndian ? "e" : "E");
82 layoutStream.flush();
83 continue;
85 if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
86 auto value = cast<IntegerAttr>(entry.getValue());
87 uint64_t space = value.getValue().getZExtValue();
88 // Skip the default address space.
89 if (space == 0)
90 continue;
91 layoutStream << "-A" << space;
92 layoutStream.flush();
93 continue;
95 if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
96 auto value = cast<IntegerAttr>(entry.getValue());
97 uint64_t alignment = value.getValue().getZExtValue();
98 // Skip the default stack alignment.
99 if (alignment == 0)
100 continue;
101 layoutStream << "-S" << alignment;
102 layoutStream.flush();
103 continue;
105 emitError(*loc) << "unsupported data layout key " << key;
106 return failure();
109 // Go through the list of entries to check which types are explicitly
110 // specified in entries. Where possible, data layout queries are used instead
111 // of directly inspecting the entries.
112 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
113 auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
114 if (!type)
115 continue;
116 // Data layout for the index type is irrelevant at this point.
117 if (isa<IndexType>(type))
118 continue;
119 layoutStream << "-";
120 LogicalResult result =
121 llvm::TypeSwitch<Type, LogicalResult>(type)
122 .Case<IntegerType, Float16Type, Float32Type, Float64Type,
123 Float80Type, Float128Type>([&](Type type) -> LogicalResult {
124 if (auto intType = dyn_cast<IntegerType>(type)) {
125 if (intType.getSignedness() != IntegerType::Signless)
126 return emitError(*loc)
127 << "unsupported data layout for non-signless integer "
128 << intType;
129 layoutStream << "i";
130 } else {
131 layoutStream << "f";
133 unsigned size = dataLayout.getTypeSizeInBits(type);
134 unsigned abi = dataLayout.getTypeABIAlignment(type) * 8u;
135 unsigned preferred =
136 dataLayout.getTypePreferredAlignment(type) * 8u;
137 layoutStream << size << ":" << abi;
138 if (abi != preferred)
139 layoutStream << ":" << preferred;
140 return success();
142 .Case([&](LLVMPointerType ptrType) {
143 layoutStream << "p" << ptrType.getAddressSpace() << ":";
144 unsigned size = dataLayout.getTypeSizeInBits(type);
145 unsigned abi = dataLayout.getTypeABIAlignment(type) * 8u;
146 unsigned preferred =
147 dataLayout.getTypePreferredAlignment(type) * 8u;
148 layoutStream << size << ":" << abi << ":" << preferred;
149 if (std::optional<unsigned> index = extractPointerSpecValue(
150 entry.getValue(), PtrDLEntryPos::Index))
151 layoutStream << ":" << *index;
152 return success();
154 .Default([loc](Type type) {
155 return emitError(*loc)
156 << "unsupported type in data layout: " << type;
158 if (failed(result))
159 return failure();
161 layoutStream.flush();
162 StringRef layoutSpec(llvmDataLayout);
163 if (layoutSpec.startswith("-"))
164 layoutSpec = layoutSpec.drop_front();
166 return llvm::DataLayout(layoutSpec);
169 /// Builds a constant of a sequential LLVM type `type`, potentially containing
170 /// other sequential types recursively, from the individual constant values
171 /// provided in `constants`. `shape` contains the number of elements in nested
172 /// sequential types. Reports errors at `loc` and returns nullptr on error.
173 static llvm::Constant *
174 buildSequentialConstant(ArrayRef<llvm::Constant *> &constants,
175 ArrayRef<int64_t> shape, llvm::Type *type,
176 Location loc) {
177 if (shape.empty()) {
178 llvm::Constant *result = constants.front();
179 constants = constants.drop_front();
180 return result;
183 llvm::Type *elementType;
184 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
185 elementType = arrayTy->getElementType();
186 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
187 elementType = vectorTy->getElementType();
188 } else {
189 emitError(loc) << "expected sequential LLVM types wrapping a scalar";
190 return nullptr;
193 SmallVector<llvm::Constant *, 8> nested;
194 nested.reserve(shape.front());
195 for (int64_t i = 0; i < shape.front(); ++i) {
196 nested.push_back(buildSequentialConstant(constants, shape.drop_front(),
197 elementType, loc));
198 if (!nested.back())
199 return nullptr;
202 if (shape.size() == 1 && type->isVectorTy())
203 return llvm::ConstantVector::get(nested);
204 return llvm::ConstantArray::get(
205 llvm::ArrayType::get(elementType, shape.front()), nested);
208 /// Returns the first non-sequential type nested in sequential types.
209 static llvm::Type *getInnermostElementType(llvm::Type *type) {
210 do {
211 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
212 type = arrayTy->getElementType();
213 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
214 type = vectorTy->getElementType();
215 } else {
216 return type;
218 } while (true);
221 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
222 /// storage if possible. This supports elements attributes of tensor or vector
223 /// type and avoids constructing separate objects for individual values of the
224 /// innermost dimension. Constants for other dimensions are still constructed
225 /// recursively. Returns null if constructing from raw data is not supported for
226 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports
227 /// other errors at `loc`.
228 static llvm::Constant *
229 convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
230 llvm::Type *llvmType,
231 const ModuleTranslation &moduleTranslation) {
232 if (!denseElementsAttr)
233 return nullptr;
235 llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
236 if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
237 return nullptr;
239 ShapedType type = denseElementsAttr.getType();
240 if (type.getNumElements() == 0)
241 return nullptr;
243 // Check that the raw data size matches what is expected for the scalar size.
244 // TODO: in theory, we could repack the data here to keep constructing from
245 // raw data.
246 // TODO: we may also need to consider endianness when cross-compiling to an
247 // architecture where it is different.
248 unsigned elementByteSize = denseElementsAttr.getRawData().size() /
249 denseElementsAttr.getNumElements();
250 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits())
251 return nullptr;
253 // Compute the shape of all dimensions but the innermost. Note that the
254 // innermost dimension may be that of the vector element type.
255 bool hasVectorElementType = isa<VectorType>(type.getElementType());
256 unsigned numAggregates =
257 denseElementsAttr.getNumElements() /
258 (hasVectorElementType ? 1
259 : denseElementsAttr.getType().getShape().back());
260 ArrayRef<int64_t> outerShape = type.getShape();
261 if (!hasVectorElementType)
262 outerShape = outerShape.drop_back();
264 // Handle the case of vector splat, LLVM has special support for it.
265 if (denseElementsAttr.isSplat() &&
266 (isa<VectorType>(type) || hasVectorElementType)) {
267 llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
268 innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
269 moduleTranslation);
270 llvm::Constant *splatVector =
271 llvm::ConstantDataVector::getSplat(0, splatValue);
272 SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
273 ArrayRef<llvm::Constant *> constantsRef = constants;
274 return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
276 if (denseElementsAttr.isSplat())
277 return nullptr;
279 // In case of non-splat, create a constructor for the innermost constant from
280 // a piece of raw data.
281 std::function<llvm::Constant *(StringRef)> buildCstData;
282 if (isa<TensorType>(type)) {
283 auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
284 if (vectorElementType && vectorElementType.getRank() == 1) {
285 buildCstData = [&](StringRef data) {
286 return llvm::ConstantDataVector::getRaw(
287 data, vectorElementType.getShape().back(), innermostLLVMType);
289 } else if (!vectorElementType) {
290 buildCstData = [&](StringRef data) {
291 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
292 innermostLLVMType);
295 } else if (isa<VectorType>(type)) {
296 buildCstData = [&](StringRef data) {
297 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
298 innermostLLVMType);
301 if (!buildCstData)
302 return nullptr;
304 // Create innermost constants and defer to the default constant creation
305 // mechanism for other dimensions.
306 SmallVector<llvm::Constant *> constants;
307 unsigned aggregateSize = denseElementsAttr.getType().getShape().back() *
308 (innermostLLVMType->getScalarSizeInBits() / 8);
309 constants.reserve(numAggregates);
310 for (unsigned i = 0; i < numAggregates; ++i) {
311 StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
312 aggregateSize);
313 constants.push_back(buildCstData(data));
316 ArrayRef<llvm::Constant *> constantsRef = constants;
317 return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
320 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
321 /// This currently supports integer, floating point, splat and dense element
322 /// attributes and combinations thereof. Also, an array attribute with two
323 /// elements is supported to represent a complex constant. In case of error,
324 /// report it to `loc` and return nullptr.
325 llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
326 llvm::Type *llvmType, Attribute attr, Location loc,
327 const ModuleTranslation &moduleTranslation) {
328 if (!attr)
329 return llvm::UndefValue::get(llvmType);
330 if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
331 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
332 if (!arrayAttr || arrayAttr.size() != 2) {
333 emitError(loc, "expected struct type to be a complex number");
334 return nullptr;
336 llvm::Type *elementType = structType->getElementType(0);
337 llvm::Constant *real =
338 getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
339 if (!real)
340 return nullptr;
341 llvm::Constant *imag =
342 getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
343 if (!imag)
344 return nullptr;
345 return llvm::ConstantStruct::get(structType, {real, imag});
347 // For integer types, we allow a mismatch in sizes as the index type in
348 // MLIR might have a different size than the index type in the LLVM module.
349 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
350 return llvm::ConstantInt::get(
351 llvmType,
352 intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
353 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
354 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
355 // Special case for 8-bit floats, which are represented by integers due to
356 // the lack of native fp8 types in LLVM at the moment. Additionally, handle
357 // targets (like AMDGPU) that don't implement bfloat and convert all bfloats
358 // to i16.
359 unsigned floatWidth = APFloat::getSizeInBits(sem);
360 if (llvmType->isIntegerTy(floatWidth))
361 return llvm::ConstantInt::get(llvmType,
362 floatAttr.getValue().bitcastToAPInt());
363 if (llvmType !=
364 llvm::Type::getFloatingPointTy(llvmType->getContext(),
365 floatAttr.getValue().getSemantics())) {
366 emitError(loc, "FloatAttr does not match expected type of the constant");
367 return nullptr;
369 return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
371 if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
372 return llvm::ConstantExpr::getBitCast(
373 moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
374 if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
375 llvm::Type *elementType;
376 uint64_t numElements;
377 bool isScalable = false;
378 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
379 elementType = arrayTy->getElementType();
380 numElements = arrayTy->getNumElements();
381 } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
382 elementType = fVectorTy->getElementType();
383 numElements = fVectorTy->getNumElements();
384 } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
385 elementType = sVectorTy->getElementType();
386 numElements = sVectorTy->getMinNumElements();
387 isScalable = true;
388 } else {
389 llvm_unreachable("unrecognized constant vector type");
391 // Splat value is a scalar. Extract it only if the element type is not
392 // another sequence type. The recursion terminates because each step removes
393 // one outer sequential type.
394 bool elementTypeSequential =
395 isa<llvm::ArrayType, llvm::VectorType>(elementType);
396 llvm::Constant *child = getLLVMConstant(
397 elementType,
398 elementTypeSequential ? splatAttr
399 : splatAttr.getSplatValue<Attribute>(),
400 loc, moduleTranslation);
401 if (!child)
402 return nullptr;
403 if (llvmType->isVectorTy())
404 return llvm::ConstantVector::getSplat(
405 llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
406 if (llvmType->isArrayTy()) {
407 auto *arrayType = llvm::ArrayType::get(elementType, numElements);
408 SmallVector<llvm::Constant *, 8> constants(numElements, child);
409 return llvm::ConstantArray::get(arrayType, constants);
413 // Try using raw elements data if possible.
414 if (llvm::Constant *result =
415 convertDenseElementsAttr(loc, dyn_cast<DenseElementsAttr>(attr),
416 llvmType, moduleTranslation)) {
417 return result;
420 // Fall back to element-by-element construction otherwise.
421 if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
422 assert(elementsAttr.getShapedType().hasStaticShape());
423 assert(!elementsAttr.getShapedType().getShape().empty() &&
424 "unexpected empty elements attribute shape");
426 SmallVector<llvm::Constant *, 8> constants;
427 constants.reserve(elementsAttr.getNumElements());
428 llvm::Type *innermostType = getInnermostElementType(llvmType);
429 for (auto n : elementsAttr.getValues<Attribute>()) {
430 constants.push_back(
431 getLLVMConstant(innermostType, n, loc, moduleTranslation));
432 if (!constants.back())
433 return nullptr;
435 ArrayRef<llvm::Constant *> constantsRef = constants;
436 llvm::Constant *result = buildSequentialConstant(
437 constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
438 assert(constantsRef.empty() && "did not consume all elemental constants");
439 return result;
442 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
443 return llvm::ConstantDataArray::get(
444 moduleTranslation.getLLVMContext(),
445 ArrayRef<char>{stringAttr.getValue().data(),
446 stringAttr.getValue().size()});
448 emitError(loc, "unsupported constant value");
449 return nullptr;
452 ModuleTranslation::ModuleTranslation(Operation *module,
453 std::unique_ptr<llvm::Module> llvmModule)
454 : mlirModule(module), llvmModule(std::move(llvmModule)),
455 debugTranslation(
456 std::make_unique<DebugTranslation>(module, *this->llvmModule)),
457 loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
458 *this, *this->llvmModule)),
459 typeTranslator(this->llvmModule->getContext()),
460 iface(module->getContext()) {
461 assert(satisfiesLLVMModule(mlirModule) &&
462 "mlirModule should honor LLVM's module semantics.");
465 ModuleTranslation::~ModuleTranslation() {
466 if (ompBuilder)
467 ompBuilder->finalize();
470 void ModuleTranslation::forgetMapping(Region &region) {
471 SmallVector<Region *> toProcess;
472 toProcess.push_back(&region);
473 while (!toProcess.empty()) {
474 Region *current = toProcess.pop_back_val();
475 for (Block &block : *current) {
476 blockMapping.erase(&block);
477 for (Value arg : block.getArguments())
478 valueMapping.erase(arg);
479 for (Operation &op : block) {
480 for (Value value : op.getResults())
481 valueMapping.erase(value);
482 if (op.hasSuccessors())
483 branchMapping.erase(&op);
484 if (isa<LLVM::GlobalOp>(op))
485 globalsMapping.erase(&op);
486 llvm::append_range(
487 toProcess,
488 llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
494 /// Get the SSA value passed to the current block from the terminator operation
495 /// of its predecessor.
496 static Value getPHISourceValue(Block *current, Block *pred,
497 unsigned numArguments, unsigned index) {
498 Operation &terminator = *pred->getTerminator();
499 if (isa<LLVM::BrOp>(terminator))
500 return terminator.getOperand(index);
502 #ifndef NDEBUG
503 llvm::SmallPtrSet<Block *, 4> seenSuccessors;
504 for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
505 Block *successor = terminator.getSuccessor(i);
506 auto branch = cast<BranchOpInterface>(terminator);
507 SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
508 assert(
509 (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
510 "successors with arguments in LLVM branches must be different blocks");
511 seenSuccessors.insert(successor);
513 #endif
515 // For instructions that branch based on a condition value, we need to take
516 // the operands for the branch that was taken.
517 if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
518 // For conditional branches, we take the operands from either the "true" or
519 // the "false" branch.
520 return condBranchOp.getSuccessor(0) == current
521 ? condBranchOp.getTrueDestOperands()[index]
522 : condBranchOp.getFalseDestOperands()[index];
525 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
526 // For switches, we take the operands from either the default case, or from
527 // the case branch that was taken.
528 if (switchOp.getDefaultDestination() == current)
529 return switchOp.getDefaultOperands()[index];
530 for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations()))
531 if (i.value() == current)
532 return switchOp.getCaseOperands(i.index())[index];
535 if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
536 return invokeOp.getNormalDest() == current
537 ? invokeOp.getNormalDestOperands()[index]
538 : invokeOp.getUnwindDestOperands()[index];
541 llvm_unreachable(
542 "only branch, switch or invoke operations can be terminators "
543 "of a block that has successors");
546 /// Connect the PHI nodes to the results of preceding blocks.
547 void mlir::LLVM::detail::connectPHINodes(Region &region,
548 const ModuleTranslation &state) {
549 // Skip the first block, it cannot be branched to and its arguments correspond
550 // to the arguments of the LLVM function.
551 for (Block &bb : llvm::drop_begin(region)) {
552 llvm::BasicBlock *llvmBB = state.lookupBlock(&bb);
553 auto phis = llvmBB->phis();
554 auto numArguments = bb.getNumArguments();
555 assert(numArguments == std::distance(phis.begin(), phis.end()));
556 for (auto [index, phiNode] : llvm::enumerate(phis)) {
557 for (auto *pred : bb.getPredecessors()) {
558 // Find the LLVM IR block that contains the converted terminator
559 // instruction and use it in the PHI node. Note that this block is not
560 // necessarily the same as state.lookupBlock(pred), some operations
561 // (in particular, OpenMP operations using OpenMPIRBuilder) may have
562 // split the blocks.
563 llvm::Instruction *terminator =
564 state.lookupBranch(pred->getTerminator());
565 assert(terminator && "missing the mapping for a terminator");
566 phiNode.addIncoming(state.lookupValue(getPHISourceValue(
567 &bb, pred, numArguments, index)),
568 terminator->getParent());
574 /// Sort function blocks topologically.
575 SetVector<Block *>
576 mlir::LLVM::detail::getTopologicallySortedBlocks(Region &region) {
577 // For each block that has not been visited yet (i.e. that has no
578 // predecessors), add it to the list as well as its successors.
579 SetVector<Block *> blocks;
580 for (Block &b : region) {
581 if (blocks.count(&b) == 0) {
582 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
583 blocks.insert(traversal.begin(), traversal.end());
586 assert(blocks.size() == region.getBlocks().size() &&
587 "some blocks are not sorted");
589 return blocks;
592 llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
593 llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
594 ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
595 llvm::Module *module = builder.GetInsertBlock()->getModule();
596 llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys);
597 return builder.CreateCall(fn, args);
600 /// Given a single MLIR operation, create the corresponding LLVM IR operation
601 /// using the `builder`.
602 LogicalResult
603 ModuleTranslation::convertOperation(Operation &op,
604 llvm::IRBuilderBase &builder) {
605 const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
606 if (!opIface)
607 return op.emitError("cannot be converted to LLVM IR: missing "
608 "`LLVMTranslationDialectInterface` registration for "
609 "dialect for op: ")
610 << op.getName();
612 if (failed(opIface->convertOperation(&op, builder, *this)))
613 return op.emitError("LLVM Translation failed for operation: ")
614 << op.getName();
616 return convertDialectAttributes(&op);
619 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
620 /// to define values corresponding to the MLIR block arguments. These nodes
621 /// are not connected to the source basic blocks, which may not exist yet. Uses
622 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
623 /// been created for `bb` and included in the block mapping. Inserts new
624 /// instructions at the end of the block and leaves `builder` in a state
625 /// suitable for further insertion into the end of the block.
626 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
627 llvm::IRBuilderBase &builder) {
628 builder.SetInsertPoint(lookupBlock(&bb));
629 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
631 // Before traversing operations, make block arguments available through
632 // value remapping and PHI nodes, but do not add incoming edges for the PHI
633 // nodes just yet: those values may be defined by this or following blocks.
634 // This step is omitted if "ignoreArguments" is set. The arguments of the
635 // first block have been already made available through the remapping of
636 // LLVM function arguments.
637 if (!ignoreArguments) {
638 auto predecessors = bb.getPredecessors();
639 unsigned numPredecessors =
640 std::distance(predecessors.begin(), predecessors.end());
641 for (auto arg : bb.getArguments()) {
642 auto wrappedType = arg.getType();
643 if (!isCompatibleType(wrappedType))
644 return emitError(bb.front().getLoc(),
645 "block argument does not have an LLVM type");
646 llvm::Type *type = convertType(wrappedType);
647 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
648 mapValue(arg, phi);
652 // Traverse operations.
653 for (auto &op : bb) {
654 // Set the current debug location within the builder.
655 builder.SetCurrentDebugLocation(
656 debugTranslation->translateLoc(op.getLoc(), subprogram));
658 if (failed(convertOperation(op, builder)))
659 return failure();
661 // Set the branch weight metadata on the translated instruction.
662 if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
663 setBranchWeightsMetadata(iface);
666 return success();
669 /// A helper method to get the single Block in an operation honoring LLVM's
670 /// module requirements.
671 static Block &getModuleBody(Operation *module) {
672 return module->getRegion(0).front();
675 /// A helper method to decide if a constant must not be set as a global variable
676 /// initializer. For an external linkage variable, the variable with an
677 /// initializer is considered externally visible and defined in this module, the
678 /// variable without an initializer is externally available and is defined
679 /// elsewhere.
680 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,
681 llvm::Constant *cst) {
682 return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
683 linkage == llvm::GlobalVariable::ExternalWeakLinkage;
686 /// Sets the runtime preemption specifier of `gv` to dso_local if
687 /// `dsoLocalRequested` is true, otherwise it is left unchanged.
688 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
689 llvm::GlobalValue *gv) {
690 if (dsoLocalRequested)
691 gv->setDSOLocal(true);
694 /// Create named global variables that correspond to llvm.mlir.global
695 /// definitions. Convert llvm.global_ctors and global_dtors ops.
696 LogicalResult ModuleTranslation::convertGlobals() {
697 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
698 llvm::Type *type = convertType(op.getType());
699 llvm::Constant *cst = nullptr;
700 if (op.getValueOrNull()) {
701 // String attributes are treated separately because they cannot appear as
702 // in-function constants and are thus not supported by getLLVMConstant.
703 if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
704 cst = llvm::ConstantDataArray::getString(
705 llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
706 type = cst->getType();
707 } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(),
708 *this))) {
709 return failure();
713 auto linkage = convertLinkageToLLVM(op.getLinkage());
714 auto addrSpace = op.getAddrSpace();
716 // LLVM IR requires constant with linkage other than external or weak
717 // external to have initializers. If MLIR does not provide an initializer,
718 // default to undef.
719 bool dropInitializer = shouldDropGlobalInitializer(linkage, cst);
720 if (!dropInitializer && !cst)
721 cst = llvm::UndefValue::get(type);
722 else if (dropInitializer && cst)
723 cst = nullptr;
725 auto *var = new llvm::GlobalVariable(
726 *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
727 /*InsertBefore=*/nullptr,
728 op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
729 : llvm::GlobalValue::NotThreadLocal,
730 addrSpace);
732 if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
733 auto selectorOp = cast<ComdatSelectorOp>(
734 SymbolTable::lookupNearestSymbolFrom(op, *comdat));
735 var->setComdat(comdatMapping.lookup(selectorOp));
738 if (op.getUnnamedAddr().has_value())
739 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
741 if (op.getSection().has_value())
742 var->setSection(*op.getSection());
744 addRuntimePreemptionSpecifier(op.getDsoLocal(), var);
746 std::optional<uint64_t> alignment = op.getAlignment();
747 if (alignment.has_value())
748 var->setAlignment(llvm::MaybeAlign(alignment.value()));
750 var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
752 globalsMapping.try_emplace(op, var);
755 // Convert global variable bodies. This is done after all global variables
756 // have been created in LLVM IR because a global body may refer to another
757 // global or itself. So all global variables need to be mapped first.
758 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
759 if (Block *initializer = op.getInitializerBlock()) {
760 llvm::IRBuilder<> builder(llvmModule->getContext());
761 for (auto &op : initializer->without_terminator()) {
762 if (failed(convertOperation(op, builder)) ||
763 !isa<llvm::Constant>(lookupValue(op.getResult(0))))
764 return emitError(op.getLoc(), "unemittable constant value");
766 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
767 llvm::Constant *cst =
768 cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
769 auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
770 if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
771 global->setInitializer(cst);
775 // Convert llvm.mlir.global_ctors and dtors.
776 for (Operation &op : getModuleBody(mlirModule)) {
777 auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
778 auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
779 if (!ctorOp && !dtorOp)
780 continue;
781 auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
782 : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
783 auto appendGlobalFn =
784 ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
785 for (auto symbolAndPriority : range) {
786 llvm::Function *f = lookupFunction(
787 cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue());
788 appendGlobalFn(*llvmModule, f,
789 cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(),
790 /*Data=*/nullptr);
794 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
795 if (failed(convertDialectAttributes(op)))
796 return failure();
798 return success();
801 /// Attempts to add an attribute identified by `key`, optionally with the given
802 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
803 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
804 /// otherwise keep it as a string attribute. Performs additional checks for
805 /// attributes known to have or not have a value in order to avoid assertions
806 /// inside LLVM upon construction.
807 static LogicalResult checkedAddLLVMFnAttribute(Location loc,
808 llvm::Function *llvmFunc,
809 StringRef key,
810 StringRef value = StringRef()) {
811 auto kind = llvm::Attribute::getAttrKindFromName(key);
812 if (kind == llvm::Attribute::None) {
813 llvmFunc->addFnAttr(key, value);
814 return success();
817 if (llvm::Attribute::isIntAttrKind(kind)) {
818 if (value.empty())
819 return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
821 int64_t result;
822 if (!value.getAsInteger(/*Radix=*/0, result))
823 llvmFunc->addFnAttr(
824 llvm::Attribute::get(llvmFunc->getContext(), kind, result));
825 else
826 llvmFunc->addFnAttr(key, value);
827 return success();
830 if (!value.empty())
831 return emitError(loc) << "LLVM attribute '" << key
832 << "' does not expect a value, found '" << value
833 << "'";
835 llvmFunc->addFnAttr(kind);
836 return success();
839 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
840 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
841 /// to be an array attribute containing either string attributes, treated as
842 /// value-less LLVM attributes, or array attributes containing two string
843 /// attributes, with the first string being the name of the corresponding LLVM
844 /// attribute and the second string beings its value. Note that even integer
845 /// attributes are expected to have their values expressed as strings.
846 static LogicalResult
847 forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
848 llvm::Function *llvmFunc) {
849 if (!attributes)
850 return success();
852 for (Attribute attr : *attributes) {
853 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
854 if (failed(
855 checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
856 return failure();
857 continue;
860 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
861 if (!arrayAttr || arrayAttr.size() != 2)
862 return emitError(loc)
863 << "expected 'passthrough' to contain string or array attributes";
865 auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
866 auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
867 if (!keyAttr || !valueAttr)
868 return emitError(loc)
869 << "expected arrays within 'passthrough' to contain two strings";
871 if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
872 valueAttr.getValue())))
873 return failure();
875 return success();
878 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
879 // Clear the block, branch value mappings, they are only relevant within one
880 // function.
881 blockMapping.clear();
882 valueMapping.clear();
883 branchMapping.clear();
884 llvm::Function *llvmFunc = lookupFunction(func.getName());
886 // Translate the debug information for this function.
887 debugTranslation->translate(func, *llvmFunc);
889 // Add function arguments to the value remapping table.
890 for (auto [mlirArg, llvmArg] :
891 llvm::zip(func.getArguments(), llvmFunc->args()))
892 mapValue(mlirArg, &llvmArg);
894 // Check the personality and set it.
895 if (func.getPersonality()) {
896 llvm::Type *ty = llvm::Type::getInt8PtrTy(llvmFunc->getContext());
897 if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(),
898 func.getLoc(), *this))
899 llvmFunc->setPersonalityFn(pfunc);
902 if (std::optional<StringRef> section = func.getSection())
903 llvmFunc->setSection(*section);
905 if (func.getArmStreaming())
906 llvmFunc->addFnAttr("aarch64_pstate_sm_enabled");
907 else if (func.getArmLocallyStreaming())
908 llvmFunc->addFnAttr("aarch64_pstate_sm_body");
910 // First, create all blocks so we can jump to them.
911 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
912 for (auto &bb : func) {
913 auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
914 llvmBB->insertInto(llvmFunc);
915 mapBlock(&bb, llvmBB);
918 // Then, convert blocks one by one in topological order to ensure defs are
919 // converted before uses.
920 auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
921 for (Block *bb : blocks) {
922 llvm::IRBuilder<> builder(llvmContext);
923 if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
924 return failure();
927 // After all blocks have been traversed and values mapped, connect the PHI
928 // nodes to the results of preceding blocks.
929 detail::connectPHINodes(func.getBody(), *this);
931 // Finally, convert dialect attributes attached to the function.
932 return convertDialectAttributes(func);
935 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
936 for (NamedAttribute attribute : op->getDialectAttrs())
937 if (failed(iface.amendOperation(op, attribute, *this)))
938 return failure();
939 return success();
942 /// Converts the function attributes from LLVMFuncOp and attaches them to the
943 /// llvm::Function.
944 static void convertFunctionAttributes(LLVMFuncOp func,
945 llvm::Function *llvmFunc) {
946 if (!func.getMemory())
947 return;
949 MemoryEffectsAttr memEffects = func.getMemoryAttr();
951 // Add memory effects incrementally.
952 llvm::MemoryEffects newMemEffects =
953 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
954 convertModRefInfoToLLVM(memEffects.getArgMem()));
955 newMemEffects |= llvm::MemoryEffects(
956 llvm::MemoryEffects::Location::InaccessibleMem,
957 convertModRefInfoToLLVM(memEffects.getInaccessibleMem()));
958 newMemEffects |=
959 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
960 convertModRefInfoToLLVM(memEffects.getOther()));
961 llvmFunc->setMemoryEffects(newMemEffects);
964 llvm::AttrBuilder
965 ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
966 llvm::AttrBuilder attrBuilder(llvmModule->getContext());
968 for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
969 Attribute attr = paramAttrs.get(mlirName);
970 // Skip attributes that are not present.
971 if (!attr)
972 continue;
974 // NOTE: C++17 does not support capturing structured bindings.
975 llvm::Attribute::AttrKind llvmKindCap = llvmKind;
977 llvm::TypeSwitch<Attribute>(attr)
978 .Case<TypeAttr>([&](auto typeAttr) {
979 attrBuilder.addTypeAttr(llvmKindCap,
980 convertType(typeAttr.getValue()));
982 .Case<IntegerAttr>([&](auto intAttr) {
983 attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
985 .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
988 return attrBuilder;
991 LogicalResult ModuleTranslation::convertFunctionSignatures() {
992 // Declare all functions first because there may be function calls that form a
993 // call graph with cycles, or global initializers that reference functions.
994 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
995 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
996 function.getName(),
997 cast<llvm::FunctionType>(convertType(function.getFunctionType())));
998 llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
999 llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage()));
1000 llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv()));
1001 mapFunction(function.getName(), llvmFunc);
1002 addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc);
1004 // Convert function attributes.
1005 convertFunctionAttributes(function, llvmFunc);
1007 // Convert function_entry_count attribute to metadata.
1008 if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
1009 llvmFunc->setEntryCount(entryCount.value());
1011 // Convert result attributes.
1012 if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
1013 DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1014 llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
1017 // Convert argument attributes.
1018 for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
1019 if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
1020 llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
1021 llvmArg.addAttrs(attrBuilder);
1025 // Forward the pass-through attributes to LLVM.
1026 if (failed(forwardPassthroughAttributes(
1027 function.getLoc(), function.getPassthrough(), llvmFunc)))
1028 return failure();
1030 // Convert visibility attribute.
1031 llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_()));
1033 // Convert the comdat attribute.
1034 if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) {
1035 auto selectorOp = cast<ComdatSelectorOp>(
1036 SymbolTable::lookupNearestSymbolFrom(function, *comdat));
1037 llvmFunc->setComdat(comdatMapping.lookup(selectorOp));
1040 if (auto gc = function.getGarbageCollector())
1041 llvmFunc->setGC(gc->str());
1043 if (auto unnamedAddr = function.getUnnamedAddr())
1044 llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr));
1046 if (auto alignment = function.getAlignment())
1047 llvmFunc->setAlignment(llvm::MaybeAlign(*alignment));
1050 return success();
1053 LogicalResult ModuleTranslation::convertFunctions() {
1054 // Convert functions.
1055 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1056 // Do not convert external functions, but do process dialect attributes
1057 // attached to them.
1058 if (function.isExternal()) {
1059 if (failed(convertDialectAttributes(function)))
1060 return failure();
1061 continue;
1064 if (failed(convertOneFunction(function)))
1065 return failure();
1068 return success();
1071 LogicalResult ModuleTranslation::convertComdats() {
1072 for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) {
1073 for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
1074 llvm::Module *module = getLLVMModule();
1075 if (module->getComdatSymbolTable().contains(selectorOp.getSymName()))
1076 return emitError(selectorOp.getLoc())
1077 << "comdat selection symbols must be unique even in different "
1078 "comdat regions";
1079 llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName());
1080 comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
1081 comdatMapping.try_emplace(selectorOp, comdat);
1084 return success();
1087 void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op,
1088 llvm::Instruction *inst) {
1089 if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
1090 inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
1093 llvm::MDNode *
1094 ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) {
1095 auto [scopeIt, scopeInserted] =
1096 aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr);
1097 if (!scopeInserted)
1098 return scopeIt->second;
1099 llvm::LLVMContext &ctx = llvmModule->getContext();
1100 // Convert the domain metadata node if necessary.
1101 auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace(
1102 aliasScopeAttr.getDomain(), nullptr);
1103 if (insertedDomain) {
1104 llvm::SmallVector<llvm::Metadata *, 2> operands;
1105 // Placeholder for self-reference.
1106 operands.push_back({});
1107 if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
1108 operands.push_back(llvm::MDString::get(ctx, description));
1109 domainIt->second = llvm::MDNode::get(ctx, operands);
1110 // Self-reference for uniqueness.
1111 domainIt->second->replaceOperandWith(0, domainIt->second);
1113 // Convert the scope metadata node.
1114 assert(domainIt->second && "Scope's domain should already be valid");
1115 llvm::SmallVector<llvm::Metadata *, 3> operands;
1116 // Placeholder for self-reference.
1117 operands.push_back({});
1118 operands.push_back(domainIt->second);
1119 if (StringAttr description = aliasScopeAttr.getDescription())
1120 operands.push_back(llvm::MDString::get(ctx, description));
1121 scopeIt->second = llvm::MDNode::get(ctx, operands);
1122 // Self-reference for uniqueness.
1123 scopeIt->second->replaceOperandWith(0, scopeIt->second);
1124 return scopeIt->second;
1127 llvm::MDNode *ModuleTranslation::getOrCreateAliasScopes(
1128 ArrayRef<AliasScopeAttr> aliasScopeAttrs) {
1129 SmallVector<llvm::Metadata *> nodes;
1130 nodes.reserve(aliasScopeAttrs.size());
1131 for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs)
1132 nodes.push_back(getOrCreateAliasScope(aliasScopeAttr));
1133 return llvm::MDNode::get(getLLVMContext(), nodes);
1136 void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
1137 llvm::Instruction *inst) {
1138 auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) {
1139 if (!aliasScopeAttrs || aliasScopeAttrs.empty())
1140 return;
1141 llvm::MDNode *node = getOrCreateAliasScopes(
1142 llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>()));
1143 inst->setMetadata(kind, node);
1146 populateScopeMetadata(op.getAliasScopesOrNull(),
1147 llvm::LLVMContext::MD_alias_scope);
1148 populateScopeMetadata(op.getNoAliasScopesOrNull(),
1149 llvm::LLVMContext::MD_noalias);
1152 llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
1153 return tbaaMetadataMapping.lookup(tbaaAttr);
1156 void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
1157 llvm::Instruction *inst) {
1158 ArrayAttr tagRefs = op.getTBAATagsOrNull();
1159 if (!tagRefs || tagRefs.empty())
1160 return;
1162 // LLVM IR currently does not support attaching more than one TBAA access tag
1163 // to a memory accessing instruction. It may be useful to support this in
1164 // future, but for the time being just ignore the metadata if MLIR operation
1165 // has multiple access tags.
1166 if (tagRefs.size() > 1) {
1167 op.emitWarning() << "TBAA access tags were not translated, because LLVM "
1168 "IR only supports a single tag per instruction";
1169 return;
1172 llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
1173 inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
1176 void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
1177 DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
1178 if (!weightsAttr)
1179 return;
1181 llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
1182 assert(inst && "expected the operation to have a mapping to an instruction");
1183 SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
1184 inst->setMetadata(
1185 llvm::LLVMContext::MD_prof,
1186 llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
1189 LogicalResult ModuleTranslation::createTBAAMetadata() {
1190 llvm::LLVMContext &ctx = llvmModule->getContext();
1191 llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
1193 // Walk the entire module and create all metadata nodes for the TBAA
1194 // attributes. The code below relies on two invariants of the
1195 // `AttrTypeWalker`:
1196 // 1. Attributes are visited in post-order: Since the attributes create a DAG,
1197 // this ensures that any lookups into `tbaaMetadataMapping` for child
1198 // attributes succeed.
1199 // 2. Attributes are only ever visited once: This way we don't leak any
1200 // LLVM metadata instances.
1201 AttrTypeWalker walker;
1202 walker.addWalk([&](TBAARootAttr root) {
1203 tbaaMetadataMapping.insert(
1204 {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))});
1207 walker.addWalk([&](TBAATypeDescriptorAttr descriptor) {
1208 SmallVector<llvm::Metadata *> operands;
1209 operands.push_back(llvm::MDString::get(ctx, descriptor.getId()));
1210 for (TBAAMemberAttr member : descriptor.getMembers()) {
1211 operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc()));
1212 operands.push_back(llvm::ConstantAsMetadata::get(
1213 llvm::ConstantInt::get(offsetTy, member.getOffset())));
1216 tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
1219 walker.addWalk([&](TBAATagAttr tag) {
1220 SmallVector<llvm::Metadata *> operands;
1222 operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
1223 operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
1225 operands.push_back(llvm::ConstantAsMetadata::get(
1226 llvm::ConstantInt::get(offsetTy, tag.getOffset())));
1227 if (tag.getConstant())
1228 operands.push_back(
1229 llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(offsetTy, 1)));
1231 tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
1234 mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
1235 if (auto attr = analysisOpInterface.getTBAATagsOrNull())
1236 walker.walk(attr);
1239 return success();
1242 void ModuleTranslation::setLoopMetadata(Operation *op,
1243 llvm::Instruction *inst) {
1244 LoopAnnotationAttr attr =
1245 TypeSwitch<Operation *, LoopAnnotationAttr>(op)
1246 .Case<LLVM::BrOp, LLVM::CondBrOp>(
1247 [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
1248 if (!attr)
1249 return;
1250 llvm::MDNode *loopMD =
1251 loopAnnotationTranslation->translateLoopAnnotation(attr, op);
1252 inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
1255 llvm::Type *ModuleTranslation::convertType(Type type) {
1256 return typeTranslator.translateType(type);
1259 /// A helper to look up remapped operands in the value remapping table.
1260 SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) {
1261 SmallVector<llvm::Value *> remapped;
1262 remapped.reserve(values.size());
1263 for (Value v : values)
1264 remapped.push_back(lookupValue(v));
1265 return remapped;
1268 llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
1269 if (!ompBuilder) {
1270 ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
1271 ompBuilder->initialize();
1273 // Flags represented as top-level OpenMP dialect attributes are set in
1274 // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set
1275 // the default configuration.
1276 ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig(
1277 /* IsTargetDevice = */ false, /* IsGPU = */ false,
1278 /* OpenMPOffloadMandatory = */ false,
1279 /* HasRequiresReverseOffload = */ false,
1280 /* HasRequiresUnifiedAddress = */ false,
1281 /* HasRequiresUnifiedSharedMemory = */ false,
1282 /* HasRequiresDynamicAllocators = */ false));
1284 return ompBuilder.get();
1287 llvm::DILocation *ModuleTranslation::translateLoc(Location loc,
1288 llvm::DILocalScope *scope) {
1289 return debugTranslation->translateLoc(loc, scope);
1292 llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
1293 return debugTranslation->translate(attr);
1296 llvm::NamedMDNode *
1297 ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
1298 return llvmModule->getOrInsertNamedMetadata(name);
1301 void ModuleTranslation::StackFrame::anchor() {}
1303 static std::unique_ptr<llvm::Module>
1304 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
1305 StringRef name) {
1306 m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
1307 auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
1308 if (auto dataLayoutAttr =
1309 m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
1310 llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
1311 } else {
1312 FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
1313 if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
1314 if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
1315 llvmDataLayout =
1316 translateDataLayout(spec, DataLayout(iface), m->getLoc());
1318 } else if (auto mod = dyn_cast<ModuleOp>(m)) {
1319 if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
1320 llvmDataLayout =
1321 translateDataLayout(spec, DataLayout(mod), m->getLoc());
1324 if (failed(llvmDataLayout))
1325 return nullptr;
1326 llvmModule->setDataLayout(*llvmDataLayout);
1328 if (auto targetTripleAttr =
1329 m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
1330 llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue());
1332 // Inject declarations for `malloc` and `free` functions that can be used in
1333 // memref allocation/deallocation coming from standard ops lowering.
1334 llvm::IRBuilder<> builder(llvmContext);
1335 llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
1336 builder.getInt64Ty());
1337 llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
1338 builder.getInt8PtrTy());
1340 return llvmModule;
1343 std::unique_ptr<llvm::Module>
1344 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
1345 StringRef name) {
1346 if (!satisfiesLLVMModule(module)) {
1347 module->emitOpError("can not be translated to an LLVMIR module");
1348 return nullptr;
1351 std::unique_ptr<llvm::Module> llvmModule =
1352 prepareLLVMModule(module, llvmContext, name);
1353 if (!llvmModule)
1354 return nullptr;
1356 LLVM::ensureDistinctSuccessors(module);
1358 ModuleTranslation translator(module, std::move(llvmModule));
1359 llvm::IRBuilder<> llvmBuilder(llvmContext);
1361 // Convert module before functions and operations inside, so dialect
1362 // attributes can be used to change dialect-specific global configurations via
1363 // `amendOperation()`. These configurations can then influence the translation
1364 // of operations afterwards.
1365 if (failed(translator.convertOperation(*module, llvmBuilder)))
1366 return nullptr;
1368 if (failed(translator.convertComdats()))
1369 return nullptr;
1370 if (failed(translator.convertFunctionSignatures()))
1371 return nullptr;
1372 if (failed(translator.convertGlobals()))
1373 return nullptr;
1374 if (failed(translator.createTBAAMetadata()))
1375 return nullptr;
1377 // Convert other top-level operations if possible.
1378 for (Operation &o : getModuleBody(module).getOperations()) {
1379 if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
1380 LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) &&
1381 !o.hasTrait<OpTrait::IsTerminator>() &&
1382 failed(translator.convertOperation(o, llvmBuilder))) {
1383 return nullptr;
1387 // Operations in function bodies with symbolic references must be converted
1388 // after the top-level operations they refer to are declared, so we do it
1389 // last.
1390 if (failed(translator.convertFunctions()))
1391 return nullptr;
1393 if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
1394 return nullptr;
1396 return std::move(translator.llvmModule);