1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
16 #include "AttrKindDetail.h"
17 #include "DebugTranslation.h"
18 #include "LoopAnnotationTranslation.h"
19 #include "mlir/Dialect/DLTI/DLTI.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
22 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
24 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
25 #include "mlir/IR/AttrTypeSubElements.h"
26 #include "mlir/IR/Attributes.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/RegionGraphTraits.h"
30 #include "mlir/Support/LLVM.h"
31 #include "mlir/Support/LogicalResult.h"
32 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
33 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
35 #include "llvm/ADT/PostOrderIterator.h"
36 #include "llvm/ADT/SetVector.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/CFG.h"
41 #include "llvm/IR/Constants.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/InlineAsm.h"
45 #include "llvm/IR/IntrinsicsNVPTX.h"
46 #include "llvm/IR/LLVMContext.h"
47 #include "llvm/IR/MDBuilder.h"
48 #include "llvm/IR/Module.h"
49 #include "llvm/IR/Verifier.h"
50 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
51 #include "llvm/Transforms/Utils/Cloning.h"
52 #include "llvm/Transforms/Utils/ModuleUtils.h"
56 using namespace mlir::LLVM
;
57 using namespace mlir::LLVM::detail
;
59 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
61 /// Translates the given data layout spec attribute to the LLVM IR data layout.
62 /// Only integer, float, pointer and endianness entries are currently supported.
63 static FailureOr
<llvm::DataLayout
>
64 translateDataLayout(DataLayoutSpecInterface attribute
,
65 const DataLayout
&dataLayout
,
66 std::optional
<Location
> loc
= std::nullopt
) {
68 loc
= UnknownLoc::get(attribute
.getContext());
70 // Translate the endianness attribute.
71 std::string llvmDataLayout
;
72 llvm::raw_string_ostream
layoutStream(llvmDataLayout
);
73 for (DataLayoutEntryInterface entry
: attribute
.getEntries()) {
74 auto key
= llvm::dyn_cast_if_present
<StringAttr
>(entry
.getKey());
77 if (key
.getValue() == DLTIDialect::kDataLayoutEndiannessKey
) {
78 auto value
= cast
<StringAttr
>(entry
.getValue());
80 value
.getValue() == DLTIDialect::kDataLayoutEndiannessLittle
;
81 layoutStream
<< "-" << (isLittleEndian
? "e" : "E");
85 if (key
.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey
) {
86 auto value
= cast
<IntegerAttr
>(entry
.getValue());
87 uint64_t space
= value
.getValue().getZExtValue();
88 // Skip the default address space.
91 layoutStream
<< "-A" << space
;
95 if (key
.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey
) {
96 auto value
= cast
<IntegerAttr
>(entry
.getValue());
97 uint64_t alignment
= value
.getValue().getZExtValue();
98 // Skip the default stack alignment.
101 layoutStream
<< "-S" << alignment
;
102 layoutStream
.flush();
105 emitError(*loc
) << "unsupported data layout key " << key
;
109 // Go through the list of entries to check which types are explicitly
110 // specified in entries. Where possible, data layout queries are used instead
111 // of directly inspecting the entries.
112 for (DataLayoutEntryInterface entry
: attribute
.getEntries()) {
113 auto type
= llvm::dyn_cast_if_present
<Type
>(entry
.getKey());
116 // Data layout for the index type is irrelevant at this point.
117 if (isa
<IndexType
>(type
))
120 LogicalResult result
=
121 llvm::TypeSwitch
<Type
, LogicalResult
>(type
)
122 .Case
<IntegerType
, Float16Type
, Float32Type
, Float64Type
,
123 Float80Type
, Float128Type
>([&](Type type
) -> LogicalResult
{
124 if (auto intType
= dyn_cast
<IntegerType
>(type
)) {
125 if (intType
.getSignedness() != IntegerType::Signless
)
126 return emitError(*loc
)
127 << "unsupported data layout for non-signless integer "
133 unsigned size
= dataLayout
.getTypeSizeInBits(type
);
134 unsigned abi
= dataLayout
.getTypeABIAlignment(type
) * 8u;
136 dataLayout
.getTypePreferredAlignment(type
) * 8u;
137 layoutStream
<< size
<< ":" << abi
;
138 if (abi
!= preferred
)
139 layoutStream
<< ":" << preferred
;
142 .Case([&](LLVMPointerType ptrType
) {
143 layoutStream
<< "p" << ptrType
.getAddressSpace() << ":";
144 unsigned size
= dataLayout
.getTypeSizeInBits(type
);
145 unsigned abi
= dataLayout
.getTypeABIAlignment(type
) * 8u;
147 dataLayout
.getTypePreferredAlignment(type
) * 8u;
148 layoutStream
<< size
<< ":" << abi
<< ":" << preferred
;
149 if (std::optional
<unsigned> index
= extractPointerSpecValue(
150 entry
.getValue(), PtrDLEntryPos::Index
))
151 layoutStream
<< ":" << *index
;
154 .Default([loc
](Type type
) {
155 return emitError(*loc
)
156 << "unsupported type in data layout: " << type
;
161 layoutStream
.flush();
162 StringRef
layoutSpec(llvmDataLayout
);
163 if (layoutSpec
.startswith("-"))
164 layoutSpec
= layoutSpec
.drop_front();
166 return llvm::DataLayout(layoutSpec
);
169 /// Builds a constant of a sequential LLVM type `type`, potentially containing
170 /// other sequential types recursively, from the individual constant values
171 /// provided in `constants`. `shape` contains the number of elements in nested
172 /// sequential types. Reports errors at `loc` and returns nullptr on error.
173 static llvm::Constant
*
174 buildSequentialConstant(ArrayRef
<llvm::Constant
*> &constants
,
175 ArrayRef
<int64_t> shape
, llvm::Type
*type
,
178 llvm::Constant
*result
= constants
.front();
179 constants
= constants
.drop_front();
183 llvm::Type
*elementType
;
184 if (auto *arrayTy
= dyn_cast
<llvm::ArrayType
>(type
)) {
185 elementType
= arrayTy
->getElementType();
186 } else if (auto *vectorTy
= dyn_cast
<llvm::VectorType
>(type
)) {
187 elementType
= vectorTy
->getElementType();
189 emitError(loc
) << "expected sequential LLVM types wrapping a scalar";
193 SmallVector
<llvm::Constant
*, 8> nested
;
194 nested
.reserve(shape
.front());
195 for (int64_t i
= 0; i
< shape
.front(); ++i
) {
196 nested
.push_back(buildSequentialConstant(constants
, shape
.drop_front(),
202 if (shape
.size() == 1 && type
->isVectorTy())
203 return llvm::ConstantVector::get(nested
);
204 return llvm::ConstantArray::get(
205 llvm::ArrayType::get(elementType
, shape
.front()), nested
);
208 /// Returns the first non-sequential type nested in sequential types.
209 static llvm::Type
*getInnermostElementType(llvm::Type
*type
) {
211 if (auto *arrayTy
= dyn_cast
<llvm::ArrayType
>(type
)) {
212 type
= arrayTy
->getElementType();
213 } else if (auto *vectorTy
= dyn_cast
<llvm::VectorType
>(type
)) {
214 type
= vectorTy
->getElementType();
221 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
222 /// storage if possible. This supports elements attributes of tensor or vector
223 /// type and avoids constructing separate objects for individual values of the
224 /// innermost dimension. Constants for other dimensions are still constructed
225 /// recursively. Returns null if constructing from raw data is not supported for
226 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports
227 /// other errors at `loc`.
228 static llvm::Constant
*
229 convertDenseElementsAttr(Location loc
, DenseElementsAttr denseElementsAttr
,
230 llvm::Type
*llvmType
,
231 const ModuleTranslation
&moduleTranslation
) {
232 if (!denseElementsAttr
)
235 llvm::Type
*innermostLLVMType
= getInnermostElementType(llvmType
);
236 if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType
))
239 ShapedType type
= denseElementsAttr
.getType();
240 if (type
.getNumElements() == 0)
243 // Check that the raw data size matches what is expected for the scalar size.
244 // TODO: in theory, we could repack the data here to keep constructing from
246 // TODO: we may also need to consider endianness when cross-compiling to an
247 // architecture where it is different.
248 unsigned elementByteSize
= denseElementsAttr
.getRawData().size() /
249 denseElementsAttr
.getNumElements();
250 if (8 * elementByteSize
!= innermostLLVMType
->getScalarSizeInBits())
253 // Compute the shape of all dimensions but the innermost. Note that the
254 // innermost dimension may be that of the vector element type.
255 bool hasVectorElementType
= isa
<VectorType
>(type
.getElementType());
256 unsigned numAggregates
=
257 denseElementsAttr
.getNumElements() /
258 (hasVectorElementType
? 1
259 : denseElementsAttr
.getType().getShape().back());
260 ArrayRef
<int64_t> outerShape
= type
.getShape();
261 if (!hasVectorElementType
)
262 outerShape
= outerShape
.drop_back();
264 // Handle the case of vector splat, LLVM has special support for it.
265 if (denseElementsAttr
.isSplat() &&
266 (isa
<VectorType
>(type
) || hasVectorElementType
)) {
267 llvm::Constant
*splatValue
= LLVM::detail::getLLVMConstant(
268 innermostLLVMType
, denseElementsAttr
.getSplatValue
<Attribute
>(), loc
,
270 llvm::Constant
*splatVector
=
271 llvm::ConstantDataVector::getSplat(0, splatValue
);
272 SmallVector
<llvm::Constant
*> constants(numAggregates
, splatVector
);
273 ArrayRef
<llvm::Constant
*> constantsRef
= constants
;
274 return buildSequentialConstant(constantsRef
, outerShape
, llvmType
, loc
);
276 if (denseElementsAttr
.isSplat())
279 // In case of non-splat, create a constructor for the innermost constant from
280 // a piece of raw data.
281 std::function
<llvm::Constant
*(StringRef
)> buildCstData
;
282 if (isa
<TensorType
>(type
)) {
283 auto vectorElementType
= dyn_cast
<VectorType
>(type
.getElementType());
284 if (vectorElementType
&& vectorElementType
.getRank() == 1) {
285 buildCstData
= [&](StringRef data
) {
286 return llvm::ConstantDataVector::getRaw(
287 data
, vectorElementType
.getShape().back(), innermostLLVMType
);
289 } else if (!vectorElementType
) {
290 buildCstData
= [&](StringRef data
) {
291 return llvm::ConstantDataArray::getRaw(data
, type
.getShape().back(),
295 } else if (isa
<VectorType
>(type
)) {
296 buildCstData
= [&](StringRef data
) {
297 return llvm::ConstantDataVector::getRaw(data
, type
.getShape().back(),
304 // Create innermost constants and defer to the default constant creation
305 // mechanism for other dimensions.
306 SmallVector
<llvm::Constant
*> constants
;
307 unsigned aggregateSize
= denseElementsAttr
.getType().getShape().back() *
308 (innermostLLVMType
->getScalarSizeInBits() / 8);
309 constants
.reserve(numAggregates
);
310 for (unsigned i
= 0; i
< numAggregates
; ++i
) {
311 StringRef
data(denseElementsAttr
.getRawData().data() + i
* aggregateSize
,
313 constants
.push_back(buildCstData(data
));
316 ArrayRef
<llvm::Constant
*> constantsRef
= constants
;
317 return buildSequentialConstant(constantsRef
, outerShape
, llvmType
, loc
);
320 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
321 /// This currently supports integer, floating point, splat and dense element
322 /// attributes and combinations thereof. Also, an array attribute with two
323 /// elements is supported to represent a complex constant. In case of error,
324 /// report it to `loc` and return nullptr.
325 llvm::Constant
*mlir::LLVM::detail::getLLVMConstant(
326 llvm::Type
*llvmType
, Attribute attr
, Location loc
,
327 const ModuleTranslation
&moduleTranslation
) {
329 return llvm::UndefValue::get(llvmType
);
330 if (auto *structType
= dyn_cast
<::llvm::StructType
>(llvmType
)) {
331 auto arrayAttr
= dyn_cast
<ArrayAttr
>(attr
);
332 if (!arrayAttr
|| arrayAttr
.size() != 2) {
333 emitError(loc
, "expected struct type to be a complex number");
336 llvm::Type
*elementType
= structType
->getElementType(0);
337 llvm::Constant
*real
=
338 getLLVMConstant(elementType
, arrayAttr
[0], loc
, moduleTranslation
);
341 llvm::Constant
*imag
=
342 getLLVMConstant(elementType
, arrayAttr
[1], loc
, moduleTranslation
);
345 return llvm::ConstantStruct::get(structType
, {real
, imag
});
347 // For integer types, we allow a mismatch in sizes as the index type in
348 // MLIR might have a different size than the index type in the LLVM module.
349 if (auto intAttr
= dyn_cast
<IntegerAttr
>(attr
))
350 return llvm::ConstantInt::get(
352 intAttr
.getValue().sextOrTrunc(llvmType
->getIntegerBitWidth()));
353 if (auto floatAttr
= dyn_cast
<FloatAttr
>(attr
)) {
354 const llvm::fltSemantics
&sem
= floatAttr
.getValue().getSemantics();
355 // Special case for 8-bit floats, which are represented by integers due to
356 // the lack of native fp8 types in LLVM at the moment. Additionally, handle
357 // targets (like AMDGPU) that don't implement bfloat and convert all bfloats
359 unsigned floatWidth
= APFloat::getSizeInBits(sem
);
360 if (llvmType
->isIntegerTy(floatWidth
))
361 return llvm::ConstantInt::get(llvmType
,
362 floatAttr
.getValue().bitcastToAPInt());
364 llvm::Type::getFloatingPointTy(llvmType
->getContext(),
365 floatAttr
.getValue().getSemantics())) {
366 emitError(loc
, "FloatAttr does not match expected type of the constant");
369 return llvm::ConstantFP::get(llvmType
, floatAttr
.getValue());
371 if (auto funcAttr
= dyn_cast
<FlatSymbolRefAttr
>(attr
))
372 return llvm::ConstantExpr::getBitCast(
373 moduleTranslation
.lookupFunction(funcAttr
.getValue()), llvmType
);
374 if (auto splatAttr
= dyn_cast
<SplatElementsAttr
>(attr
)) {
375 llvm::Type
*elementType
;
376 uint64_t numElements
;
377 bool isScalable
= false;
378 if (auto *arrayTy
= dyn_cast
<llvm::ArrayType
>(llvmType
)) {
379 elementType
= arrayTy
->getElementType();
380 numElements
= arrayTy
->getNumElements();
381 } else if (auto *fVectorTy
= dyn_cast
<llvm::FixedVectorType
>(llvmType
)) {
382 elementType
= fVectorTy
->getElementType();
383 numElements
= fVectorTy
->getNumElements();
384 } else if (auto *sVectorTy
= dyn_cast
<llvm::ScalableVectorType
>(llvmType
)) {
385 elementType
= sVectorTy
->getElementType();
386 numElements
= sVectorTy
->getMinNumElements();
389 llvm_unreachable("unrecognized constant vector type");
391 // Splat value is a scalar. Extract it only if the element type is not
392 // another sequence type. The recursion terminates because each step removes
393 // one outer sequential type.
394 bool elementTypeSequential
=
395 isa
<llvm::ArrayType
, llvm::VectorType
>(elementType
);
396 llvm::Constant
*child
= getLLVMConstant(
398 elementTypeSequential
? splatAttr
399 : splatAttr
.getSplatValue
<Attribute
>(),
400 loc
, moduleTranslation
);
403 if (llvmType
->isVectorTy())
404 return llvm::ConstantVector::getSplat(
405 llvm::ElementCount::get(numElements
, /*Scalable=*/isScalable
), child
);
406 if (llvmType
->isArrayTy()) {
407 auto *arrayType
= llvm::ArrayType::get(elementType
, numElements
);
408 SmallVector
<llvm::Constant
*, 8> constants(numElements
, child
);
409 return llvm::ConstantArray::get(arrayType
, constants
);
413 // Try using raw elements data if possible.
414 if (llvm::Constant
*result
=
415 convertDenseElementsAttr(loc
, dyn_cast
<DenseElementsAttr
>(attr
),
416 llvmType
, moduleTranslation
)) {
420 // Fall back to element-by-element construction otherwise.
421 if (auto elementsAttr
= dyn_cast
<ElementsAttr
>(attr
)) {
422 assert(elementsAttr
.getShapedType().hasStaticShape());
423 assert(!elementsAttr
.getShapedType().getShape().empty() &&
424 "unexpected empty elements attribute shape");
426 SmallVector
<llvm::Constant
*, 8> constants
;
427 constants
.reserve(elementsAttr
.getNumElements());
428 llvm::Type
*innermostType
= getInnermostElementType(llvmType
);
429 for (auto n
: elementsAttr
.getValues
<Attribute
>()) {
431 getLLVMConstant(innermostType
, n
, loc
, moduleTranslation
));
432 if (!constants
.back())
435 ArrayRef
<llvm::Constant
*> constantsRef
= constants
;
436 llvm::Constant
*result
= buildSequentialConstant(
437 constantsRef
, elementsAttr
.getShapedType().getShape(), llvmType
, loc
);
438 assert(constantsRef
.empty() && "did not consume all elemental constants");
442 if (auto stringAttr
= dyn_cast
<StringAttr
>(attr
)) {
443 return llvm::ConstantDataArray::get(
444 moduleTranslation
.getLLVMContext(),
445 ArrayRef
<char>{stringAttr
.getValue().data(),
446 stringAttr
.getValue().size()});
448 emitError(loc
, "unsupported constant value");
452 ModuleTranslation::ModuleTranslation(Operation
*module
,
453 std::unique_ptr
<llvm::Module
> llvmModule
)
454 : mlirModule(module
), llvmModule(std::move(llvmModule
)),
456 std::make_unique
<DebugTranslation
>(module
, *this->llvmModule
)),
457 loopAnnotationTranslation(std::make_unique
<LoopAnnotationTranslation
>(
458 *this, *this->llvmModule
)),
459 typeTranslator(this->llvmModule
->getContext()),
460 iface(module
->getContext()) {
461 assert(satisfiesLLVMModule(mlirModule
) &&
462 "mlirModule should honor LLVM's module semantics.");
465 ModuleTranslation::~ModuleTranslation() {
467 ompBuilder
->finalize();
470 void ModuleTranslation::forgetMapping(Region
®ion
) {
471 SmallVector
<Region
*> toProcess
;
472 toProcess
.push_back(®ion
);
473 while (!toProcess
.empty()) {
474 Region
*current
= toProcess
.pop_back_val();
475 for (Block
&block
: *current
) {
476 blockMapping
.erase(&block
);
477 for (Value arg
: block
.getArguments())
478 valueMapping
.erase(arg
);
479 for (Operation
&op
: block
) {
480 for (Value value
: op
.getResults())
481 valueMapping
.erase(value
);
482 if (op
.hasSuccessors())
483 branchMapping
.erase(&op
);
484 if (isa
<LLVM::GlobalOp
>(op
))
485 globalsMapping
.erase(&op
);
488 llvm::map_range(op
.getRegions(), [](Region
&r
) { return &r
; }));
494 /// Get the SSA value passed to the current block from the terminator operation
495 /// of its predecessor.
496 static Value
getPHISourceValue(Block
*current
, Block
*pred
,
497 unsigned numArguments
, unsigned index
) {
498 Operation
&terminator
= *pred
->getTerminator();
499 if (isa
<LLVM::BrOp
>(terminator
))
500 return terminator
.getOperand(index
);
503 llvm::SmallPtrSet
<Block
*, 4> seenSuccessors
;
504 for (unsigned i
= 0, e
= terminator
.getNumSuccessors(); i
< e
; ++i
) {
505 Block
*successor
= terminator
.getSuccessor(i
);
506 auto branch
= cast
<BranchOpInterface
>(terminator
);
507 SuccessorOperands successorOperands
= branch
.getSuccessorOperands(i
);
509 (!seenSuccessors
.contains(successor
) || successorOperands
.empty()) &&
510 "successors with arguments in LLVM branches must be different blocks");
511 seenSuccessors
.insert(successor
);
515 // For instructions that branch based on a condition value, we need to take
516 // the operands for the branch that was taken.
517 if (auto condBranchOp
= dyn_cast
<LLVM::CondBrOp
>(terminator
)) {
518 // For conditional branches, we take the operands from either the "true" or
519 // the "false" branch.
520 return condBranchOp
.getSuccessor(0) == current
521 ? condBranchOp
.getTrueDestOperands()[index
]
522 : condBranchOp
.getFalseDestOperands()[index
];
525 if (auto switchOp
= dyn_cast
<LLVM::SwitchOp
>(terminator
)) {
526 // For switches, we take the operands from either the default case, or from
527 // the case branch that was taken.
528 if (switchOp
.getDefaultDestination() == current
)
529 return switchOp
.getDefaultOperands()[index
];
530 for (const auto &i
: llvm::enumerate(switchOp
.getCaseDestinations()))
531 if (i
.value() == current
)
532 return switchOp
.getCaseOperands(i
.index())[index
];
535 if (auto invokeOp
= dyn_cast
<LLVM::InvokeOp
>(terminator
)) {
536 return invokeOp
.getNormalDest() == current
537 ? invokeOp
.getNormalDestOperands()[index
]
538 : invokeOp
.getUnwindDestOperands()[index
];
542 "only branch, switch or invoke operations can be terminators "
543 "of a block that has successors");
546 /// Connect the PHI nodes to the results of preceding blocks.
547 void mlir::LLVM::detail::connectPHINodes(Region
®ion
,
548 const ModuleTranslation
&state
) {
549 // Skip the first block, it cannot be branched to and its arguments correspond
550 // to the arguments of the LLVM function.
551 for (Block
&bb
: llvm::drop_begin(region
)) {
552 llvm::BasicBlock
*llvmBB
= state
.lookupBlock(&bb
);
553 auto phis
= llvmBB
->phis();
554 auto numArguments
= bb
.getNumArguments();
555 assert(numArguments
== std::distance(phis
.begin(), phis
.end()));
556 for (auto [index
, phiNode
] : llvm::enumerate(phis
)) {
557 for (auto *pred
: bb
.getPredecessors()) {
558 // Find the LLVM IR block that contains the converted terminator
559 // instruction and use it in the PHI node. Note that this block is not
560 // necessarily the same as state.lookupBlock(pred), some operations
561 // (in particular, OpenMP operations using OpenMPIRBuilder) may have
563 llvm::Instruction
*terminator
=
564 state
.lookupBranch(pred
->getTerminator());
565 assert(terminator
&& "missing the mapping for a terminator");
566 phiNode
.addIncoming(state
.lookupValue(getPHISourceValue(
567 &bb
, pred
, numArguments
, index
)),
568 terminator
->getParent());
574 /// Sort function blocks topologically.
576 mlir::LLVM::detail::getTopologicallySortedBlocks(Region
®ion
) {
577 // For each block that has not been visited yet (i.e. that has no
578 // predecessors), add it to the list as well as its successors.
579 SetVector
<Block
*> blocks
;
580 for (Block
&b
: region
) {
581 if (blocks
.count(&b
) == 0) {
582 llvm::ReversePostOrderTraversal
<Block
*> traversal(&b
);
583 blocks
.insert(traversal
.begin(), traversal
.end());
586 assert(blocks
.size() == region
.getBlocks().size() &&
587 "some blocks are not sorted");
592 llvm::CallInst
*mlir::LLVM::detail::createIntrinsicCall(
593 llvm::IRBuilderBase
&builder
, llvm::Intrinsic::ID intrinsic
,
594 ArrayRef
<llvm::Value
*> args
, ArrayRef
<llvm::Type
*> tys
) {
595 llvm::Module
*module
= builder
.GetInsertBlock()->getModule();
596 llvm::Function
*fn
= llvm::Intrinsic::getDeclaration(module
, intrinsic
, tys
);
597 return builder
.CreateCall(fn
, args
);
600 /// Given a single MLIR operation, create the corresponding LLVM IR operation
601 /// using the `builder`.
603 ModuleTranslation::convertOperation(Operation
&op
,
604 llvm::IRBuilderBase
&builder
) {
605 const LLVMTranslationDialectInterface
*opIface
= iface
.getInterfaceFor(&op
);
607 return op
.emitError("cannot be converted to LLVM IR: missing "
608 "`LLVMTranslationDialectInterface` registration for "
612 if (failed(opIface
->convertOperation(&op
, builder
, *this)))
613 return op
.emitError("LLVM Translation failed for operation: ")
616 return convertDialectAttributes(&op
);
619 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
620 /// to define values corresponding to the MLIR block arguments. These nodes
621 /// are not connected to the source basic blocks, which may not exist yet. Uses
622 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
623 /// been created for `bb` and included in the block mapping. Inserts new
624 /// instructions at the end of the block and leaves `builder` in a state
625 /// suitable for further insertion into the end of the block.
626 LogicalResult
ModuleTranslation::convertBlock(Block
&bb
, bool ignoreArguments
,
627 llvm::IRBuilderBase
&builder
) {
628 builder
.SetInsertPoint(lookupBlock(&bb
));
629 auto *subprogram
= builder
.GetInsertBlock()->getParent()->getSubprogram();
631 // Before traversing operations, make block arguments available through
632 // value remapping and PHI nodes, but do not add incoming edges for the PHI
633 // nodes just yet: those values may be defined by this or following blocks.
634 // This step is omitted if "ignoreArguments" is set. The arguments of the
635 // first block have been already made available through the remapping of
636 // LLVM function arguments.
637 if (!ignoreArguments
) {
638 auto predecessors
= bb
.getPredecessors();
639 unsigned numPredecessors
=
640 std::distance(predecessors
.begin(), predecessors
.end());
641 for (auto arg
: bb
.getArguments()) {
642 auto wrappedType
= arg
.getType();
643 if (!isCompatibleType(wrappedType
))
644 return emitError(bb
.front().getLoc(),
645 "block argument does not have an LLVM type");
646 llvm::Type
*type
= convertType(wrappedType
);
647 llvm::PHINode
*phi
= builder
.CreatePHI(type
, numPredecessors
);
652 // Traverse operations.
653 for (auto &op
: bb
) {
654 // Set the current debug location within the builder.
655 builder
.SetCurrentDebugLocation(
656 debugTranslation
->translateLoc(op
.getLoc(), subprogram
));
658 if (failed(convertOperation(op
, builder
)))
661 // Set the branch weight metadata on the translated instruction.
662 if (auto iface
= dyn_cast
<BranchWeightOpInterface
>(op
))
663 setBranchWeightsMetadata(iface
);
669 /// A helper method to get the single Block in an operation honoring LLVM's
670 /// module requirements.
671 static Block
&getModuleBody(Operation
*module
) {
672 return module
->getRegion(0).front();
675 /// A helper method to decide if a constant must not be set as a global variable
676 /// initializer. For an external linkage variable, the variable with an
677 /// initializer is considered externally visible and defined in this module, the
678 /// variable without an initializer is externally available and is defined
680 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage
,
681 llvm::Constant
*cst
) {
682 return (linkage
== llvm::GlobalVariable::ExternalLinkage
&& !cst
) ||
683 linkage
== llvm::GlobalVariable::ExternalWeakLinkage
;
686 /// Sets the runtime preemption specifier of `gv` to dso_local if
687 /// `dsoLocalRequested` is true, otherwise it is left unchanged.
688 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested
,
689 llvm::GlobalValue
*gv
) {
690 if (dsoLocalRequested
)
691 gv
->setDSOLocal(true);
694 /// Create named global variables that correspond to llvm.mlir.global
695 /// definitions. Convert llvm.global_ctors and global_dtors ops.
696 LogicalResult
ModuleTranslation::convertGlobals() {
697 for (auto op
: getModuleBody(mlirModule
).getOps
<LLVM::GlobalOp
>()) {
698 llvm::Type
*type
= convertType(op
.getType());
699 llvm::Constant
*cst
= nullptr;
700 if (op
.getValueOrNull()) {
701 // String attributes are treated separately because they cannot appear as
702 // in-function constants and are thus not supported by getLLVMConstant.
703 if (auto strAttr
= dyn_cast_or_null
<StringAttr
>(op
.getValueOrNull())) {
704 cst
= llvm::ConstantDataArray::getString(
705 llvmModule
->getContext(), strAttr
.getValue(), /*AddNull=*/false);
706 type
= cst
->getType();
707 } else if (!(cst
= getLLVMConstant(type
, op
.getValueOrNull(), op
.getLoc(),
713 auto linkage
= convertLinkageToLLVM(op
.getLinkage());
714 auto addrSpace
= op
.getAddrSpace();
716 // LLVM IR requires constant with linkage other than external or weak
717 // external to have initializers. If MLIR does not provide an initializer,
719 bool dropInitializer
= shouldDropGlobalInitializer(linkage
, cst
);
720 if (!dropInitializer
&& !cst
)
721 cst
= llvm::UndefValue::get(type
);
722 else if (dropInitializer
&& cst
)
725 auto *var
= new llvm::GlobalVariable(
726 *llvmModule
, type
, op
.getConstant(), linkage
, cst
, op
.getSymName(),
727 /*InsertBefore=*/nullptr,
728 op
.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
729 : llvm::GlobalValue::NotThreadLocal
,
732 if (std::optional
<mlir::SymbolRefAttr
> comdat
= op
.getComdat()) {
733 auto selectorOp
= cast
<ComdatSelectorOp
>(
734 SymbolTable::lookupNearestSymbolFrom(op
, *comdat
));
735 var
->setComdat(comdatMapping
.lookup(selectorOp
));
738 if (op
.getUnnamedAddr().has_value())
739 var
->setUnnamedAddr(convertUnnamedAddrToLLVM(*op
.getUnnamedAddr()));
741 if (op
.getSection().has_value())
742 var
->setSection(*op
.getSection());
744 addRuntimePreemptionSpecifier(op
.getDsoLocal(), var
);
746 std::optional
<uint64_t> alignment
= op
.getAlignment();
747 if (alignment
.has_value())
748 var
->setAlignment(llvm::MaybeAlign(alignment
.value()));
750 var
->setVisibility(convertVisibilityToLLVM(op
.getVisibility_()));
752 globalsMapping
.try_emplace(op
, var
);
755 // Convert global variable bodies. This is done after all global variables
756 // have been created in LLVM IR because a global body may refer to another
757 // global or itself. So all global variables need to be mapped first.
758 for (auto op
: getModuleBody(mlirModule
).getOps
<LLVM::GlobalOp
>()) {
759 if (Block
*initializer
= op
.getInitializerBlock()) {
760 llvm::IRBuilder
<> builder(llvmModule
->getContext());
761 for (auto &op
: initializer
->without_terminator()) {
762 if (failed(convertOperation(op
, builder
)) ||
763 !isa
<llvm::Constant
>(lookupValue(op
.getResult(0))))
764 return emitError(op
.getLoc(), "unemittable constant value");
766 ReturnOp ret
= cast
<ReturnOp
>(initializer
->getTerminator());
767 llvm::Constant
*cst
=
768 cast
<llvm::Constant
>(lookupValue(ret
.getOperand(0)));
769 auto *global
= cast
<llvm::GlobalVariable
>(lookupGlobal(op
));
770 if (!shouldDropGlobalInitializer(global
->getLinkage(), cst
))
771 global
->setInitializer(cst
);
775 // Convert llvm.mlir.global_ctors and dtors.
776 for (Operation
&op
: getModuleBody(mlirModule
)) {
777 auto ctorOp
= dyn_cast
<GlobalCtorsOp
>(op
);
778 auto dtorOp
= dyn_cast
<GlobalDtorsOp
>(op
);
779 if (!ctorOp
&& !dtorOp
)
781 auto range
= ctorOp
? llvm::zip(ctorOp
.getCtors(), ctorOp
.getPriorities())
782 : llvm::zip(dtorOp
.getDtors(), dtorOp
.getPriorities());
783 auto appendGlobalFn
=
784 ctorOp
? llvm::appendToGlobalCtors
: llvm::appendToGlobalDtors
;
785 for (auto symbolAndPriority
: range
) {
786 llvm::Function
*f
= lookupFunction(
787 cast
<FlatSymbolRefAttr
>(std::get
<0>(symbolAndPriority
)).getValue());
788 appendGlobalFn(*llvmModule
, f
,
789 cast
<IntegerAttr
>(std::get
<1>(symbolAndPriority
)).getInt(),
794 for (auto op
: getModuleBody(mlirModule
).getOps
<LLVM::GlobalOp
>())
795 if (failed(convertDialectAttributes(op
)))
801 /// Attempts to add an attribute identified by `key`, optionally with the given
802 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
803 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
804 /// otherwise keep it as a string attribute. Performs additional checks for
805 /// attributes known to have or not have a value in order to avoid assertions
806 /// inside LLVM upon construction.
807 static LogicalResult
checkedAddLLVMFnAttribute(Location loc
,
808 llvm::Function
*llvmFunc
,
810 StringRef value
= StringRef()) {
811 auto kind
= llvm::Attribute::getAttrKindFromName(key
);
812 if (kind
== llvm::Attribute::None
) {
813 llvmFunc
->addFnAttr(key
, value
);
817 if (llvm::Attribute::isIntAttrKind(kind
)) {
819 return emitError(loc
) << "LLVM attribute '" << key
<< "' expects a value";
822 if (!value
.getAsInteger(/*Radix=*/0, result
))
824 llvm::Attribute::get(llvmFunc
->getContext(), kind
, result
));
826 llvmFunc
->addFnAttr(key
, value
);
831 return emitError(loc
) << "LLVM attribute '" << key
832 << "' does not expect a value, found '" << value
835 llvmFunc
->addFnAttr(kind
);
/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
/// Reports error to `loc` if any and returns immediately. Expects `attributes`
/// to be an array attribute containing either string attributes, treated as
/// value-less LLVM attributes, or array attributes containing two string
/// attributes, with the first string being the name of the corresponding LLVM
/// attribute and the second string being its value. Note that even integer
/// attributes are expected to have their values expressed as strings.
847 forwardPassthroughAttributes(Location loc
, std::optional
<ArrayAttr
> attributes
,
848 llvm::Function
*llvmFunc
) {
852 for (Attribute attr
: *attributes
) {
853 if (auto stringAttr
= dyn_cast
<StringAttr
>(attr
)) {
855 checkedAddLLVMFnAttribute(loc
, llvmFunc
, stringAttr
.getValue())))
860 auto arrayAttr
= dyn_cast
<ArrayAttr
>(attr
);
861 if (!arrayAttr
|| arrayAttr
.size() != 2)
862 return emitError(loc
)
863 << "expected 'passthrough' to contain string or array attributes";
865 auto keyAttr
= dyn_cast
<StringAttr
>(arrayAttr
[0]);
866 auto valueAttr
= dyn_cast
<StringAttr
>(arrayAttr
[1]);
867 if (!keyAttr
|| !valueAttr
)
868 return emitError(loc
)
869 << "expected arrays within 'passthrough' to contain two strings";
871 if (failed(checkedAddLLVMFnAttribute(loc
, llvmFunc
, keyAttr
.getValue(),
872 valueAttr
.getValue())))
878 LogicalResult
ModuleTranslation::convertOneFunction(LLVMFuncOp func
) {
879 // Clear the block, branch value mappings, they are only relevant within one
881 blockMapping
.clear();
882 valueMapping
.clear();
883 branchMapping
.clear();
884 llvm::Function
*llvmFunc
= lookupFunction(func
.getName());
886 // Translate the debug information for this function.
887 debugTranslation
->translate(func
, *llvmFunc
);
889 // Add function arguments to the value remapping table.
890 for (auto [mlirArg
, llvmArg
] :
891 llvm::zip(func
.getArguments(), llvmFunc
->args()))
892 mapValue(mlirArg
, &llvmArg
);
894 // Check the personality and set it.
895 if (func
.getPersonality()) {
896 llvm::Type
*ty
= llvm::Type::getInt8PtrTy(llvmFunc
->getContext());
897 if (llvm::Constant
*pfunc
= getLLVMConstant(ty
, func
.getPersonalityAttr(),
898 func
.getLoc(), *this))
899 llvmFunc
->setPersonalityFn(pfunc
);
902 if (std::optional
<StringRef
> section
= func
.getSection())
903 llvmFunc
->setSection(*section
);
905 if (func
.getArmStreaming())
906 llvmFunc
->addFnAttr("aarch64_pstate_sm_enabled");
907 else if (func
.getArmLocallyStreaming())
908 llvmFunc
->addFnAttr("aarch64_pstate_sm_body");
910 // First, create all blocks so we can jump to them.
911 llvm::LLVMContext
&llvmContext
= llvmFunc
->getContext();
912 for (auto &bb
: func
) {
913 auto *llvmBB
= llvm::BasicBlock::Create(llvmContext
);
914 llvmBB
->insertInto(llvmFunc
);
915 mapBlock(&bb
, llvmBB
);
918 // Then, convert blocks one by one in topological order to ensure defs are
919 // converted before uses.
920 auto blocks
= detail::getTopologicallySortedBlocks(func
.getBody());
921 for (Block
*bb
: blocks
) {
922 llvm::IRBuilder
<> builder(llvmContext
);
923 if (failed(convertBlock(*bb
, bb
->isEntryBlock(), builder
)))
927 // After all blocks have been traversed and values mapped, connect the PHI
928 // nodes to the results of preceding blocks.
929 detail::connectPHINodes(func
.getBody(), *this);
931 // Finally, convert dialect attributes attached to the function.
932 return convertDialectAttributes(func
);
935 LogicalResult
ModuleTranslation::convertDialectAttributes(Operation
*op
) {
936 for (NamedAttribute attribute
: op
->getDialectAttrs())
937 if (failed(iface
.amendOperation(op
, attribute
, *this)))
942 /// Converts the function attributes from LLVMFuncOp and attaches them to the
944 static void convertFunctionAttributes(LLVMFuncOp func
,
945 llvm::Function
*llvmFunc
) {
946 if (!func
.getMemory())
949 MemoryEffectsAttr memEffects
= func
.getMemoryAttr();
951 // Add memory effects incrementally.
952 llvm::MemoryEffects newMemEffects
=
953 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem
,
954 convertModRefInfoToLLVM(memEffects
.getArgMem()));
955 newMemEffects
|= llvm::MemoryEffects(
956 llvm::MemoryEffects::Location::InaccessibleMem
,
957 convertModRefInfoToLLVM(memEffects
.getInaccessibleMem()));
959 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other
,
960 convertModRefInfoToLLVM(memEffects
.getOther()));
961 llvmFunc
->setMemoryEffects(newMemEffects
);
965 ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs
) {
966 llvm::AttrBuilder
attrBuilder(llvmModule
->getContext());
968 for (auto [llvmKind
, mlirName
] : getAttrKindToNameMapping()) {
969 Attribute attr
= paramAttrs
.get(mlirName
);
970 // Skip attributes that are not present.
974 // NOTE: C++17 does not support capturing structured bindings.
975 llvm::Attribute::AttrKind llvmKindCap
= llvmKind
;
977 llvm::TypeSwitch
<Attribute
>(attr
)
978 .Case
<TypeAttr
>([&](auto typeAttr
) {
979 attrBuilder
.addTypeAttr(llvmKindCap
,
980 convertType(typeAttr
.getValue()));
982 .Case
<IntegerAttr
>([&](auto intAttr
) {
983 attrBuilder
.addRawIntAttr(llvmKindCap
, intAttr
.getInt());
985 .Case
<UnitAttr
>([&](auto) { attrBuilder
.addAttribute(llvmKindCap
); });
991 LogicalResult
ModuleTranslation::convertFunctionSignatures() {
992 // Declare all functions first because there may be function calls that form a
993 // call graph with cycles, or global initializers that reference functions.
994 for (auto function
: getModuleBody(mlirModule
).getOps
<LLVMFuncOp
>()) {
995 llvm::FunctionCallee llvmFuncCst
= llvmModule
->getOrInsertFunction(
997 cast
<llvm::FunctionType
>(convertType(function
.getFunctionType())));
998 llvm::Function
*llvmFunc
= cast
<llvm::Function
>(llvmFuncCst
.getCallee());
999 llvmFunc
->setLinkage(convertLinkageToLLVM(function
.getLinkage()));
1000 llvmFunc
->setCallingConv(convertCConvToLLVM(function
.getCConv()));
1001 mapFunction(function
.getName(), llvmFunc
);
1002 addRuntimePreemptionSpecifier(function
.getDsoLocal(), llvmFunc
);
1004 // Convert function attributes.
1005 convertFunctionAttributes(function
, llvmFunc
);
1007 // Convert function_entry_count attribute to metadata.
1008 if (std::optional
<uint64_t> entryCount
= function
.getFunctionEntryCount())
1009 llvmFunc
->setEntryCount(entryCount
.value());
1011 // Convert result attributes.
1012 if (ArrayAttr allResultAttrs
= function
.getAllResultAttrs()) {
1013 DictionaryAttr resultAttrs
= cast
<DictionaryAttr
>(allResultAttrs
[0]);
1014 llvmFunc
->addRetAttrs(convertParameterAttrs(resultAttrs
));
1017 // Convert argument attributes.
1018 for (auto [argIdx
, llvmArg
] : llvm::enumerate(llvmFunc
->args())) {
1019 if (DictionaryAttr argAttrs
= function
.getArgAttrDict(argIdx
)) {
1020 llvm::AttrBuilder attrBuilder
= convertParameterAttrs(argAttrs
);
1021 llvmArg
.addAttrs(attrBuilder
);
1025 // Forward the pass-through attributes to LLVM.
1026 if (failed(forwardPassthroughAttributes(
1027 function
.getLoc(), function
.getPassthrough(), llvmFunc
)))
1030 // Convert visibility attribute.
1031 llvmFunc
->setVisibility(convertVisibilityToLLVM(function
.getVisibility_()));
1033 // Convert the comdat attribute.
1034 if (std::optional
<mlir::SymbolRefAttr
> comdat
= function
.getComdat()) {
1035 auto selectorOp
= cast
<ComdatSelectorOp
>(
1036 SymbolTable::lookupNearestSymbolFrom(function
, *comdat
));
1037 llvmFunc
->setComdat(comdatMapping
.lookup(selectorOp
));
1040 if (auto gc
= function
.getGarbageCollector())
1041 llvmFunc
->setGC(gc
->str());
1043 if (auto unnamedAddr
= function
.getUnnamedAddr())
1044 llvmFunc
->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr
));
1046 if (auto alignment
= function
.getAlignment())
1047 llvmFunc
->setAlignment(llvm::MaybeAlign(*alignment
));
1053 LogicalResult
ModuleTranslation::convertFunctions() {
1054 // Convert functions.
1055 for (auto function
: getModuleBody(mlirModule
).getOps
<LLVMFuncOp
>()) {
1056 // Do not convert external functions, but do process dialect attributes
1057 // attached to them.
1058 if (function
.isExternal()) {
1059 if (failed(convertDialectAttributes(function
)))
1064 if (failed(convertOneFunction(function
)))
1071 LogicalResult
ModuleTranslation::convertComdats() {
1072 for (auto comdatOp
: getModuleBody(mlirModule
).getOps
<ComdatOp
>()) {
1073 for (auto selectorOp
: comdatOp
.getOps
<ComdatSelectorOp
>()) {
1074 llvm::Module
*module
= getLLVMModule();
1075 if (module
->getComdatSymbolTable().contains(selectorOp
.getSymName()))
1076 return emitError(selectorOp
.getLoc())
1077 << "comdat selection symbols must be unique even in different "
1079 llvm::Comdat
*comdat
= module
->getOrInsertComdat(selectorOp
.getSymName());
1080 comdat
->setSelectionKind(convertComdatToLLVM(selectorOp
.getComdat()));
1081 comdatMapping
.try_emplace(selectorOp
, comdat
);
1087 void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op
,
1088 llvm::Instruction
*inst
) {
1089 if (llvm::MDNode
*node
= loopAnnotationTranslation
->getAccessGroups(op
))
1090 inst
->setMetadata(llvm::LLVMContext::MD_access_group
, node
);
1094 ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr
) {
1095 auto [scopeIt
, scopeInserted
] =
1096 aliasScopeMetadataMapping
.try_emplace(aliasScopeAttr
, nullptr);
1098 return scopeIt
->second
;
1099 llvm::LLVMContext
&ctx
= llvmModule
->getContext();
1100 // Convert the domain metadata node if necessary.
1101 auto [domainIt
, insertedDomain
] = aliasDomainMetadataMapping
.try_emplace(
1102 aliasScopeAttr
.getDomain(), nullptr);
1103 if (insertedDomain
) {
1104 llvm::SmallVector
<llvm::Metadata
*, 2> operands
;
1105 // Placeholder for self-reference.
1106 operands
.push_back({});
1107 if (StringAttr description
= aliasScopeAttr
.getDomain().getDescription())
1108 operands
.push_back(llvm::MDString::get(ctx
, description
));
1109 domainIt
->second
= llvm::MDNode::get(ctx
, operands
);
1110 // Self-reference for uniqueness.
1111 domainIt
->second
->replaceOperandWith(0, domainIt
->second
);
1113 // Convert the scope metadata node.
1114 assert(domainIt
->second
&& "Scope's domain should already be valid");
1115 llvm::SmallVector
<llvm::Metadata
*, 3> operands
;
1116 // Placeholder for self-reference.
1117 operands
.push_back({});
1118 operands
.push_back(domainIt
->second
);
1119 if (StringAttr description
= aliasScopeAttr
.getDescription())
1120 operands
.push_back(llvm::MDString::get(ctx
, description
));
1121 scopeIt
->second
= llvm::MDNode::get(ctx
, operands
);
1122 // Self-reference for uniqueness.
1123 scopeIt
->second
->replaceOperandWith(0, scopeIt
->second
);
1124 return scopeIt
->second
;
1127 llvm::MDNode
*ModuleTranslation::getOrCreateAliasScopes(
1128 ArrayRef
<AliasScopeAttr
> aliasScopeAttrs
) {
1129 SmallVector
<llvm::Metadata
*> nodes
;
1130 nodes
.reserve(aliasScopeAttrs
.size());
1131 for (AliasScopeAttr aliasScopeAttr
: aliasScopeAttrs
)
1132 nodes
.push_back(getOrCreateAliasScope(aliasScopeAttr
));
1133 return llvm::MDNode::get(getLLVMContext(), nodes
);
1136 void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op
,
1137 llvm::Instruction
*inst
) {
1138 auto populateScopeMetadata
= [&](ArrayAttr aliasScopeAttrs
, unsigned kind
) {
1139 if (!aliasScopeAttrs
|| aliasScopeAttrs
.empty())
1141 llvm::MDNode
*node
= getOrCreateAliasScopes(
1142 llvm::to_vector(aliasScopeAttrs
.getAsRange
<AliasScopeAttr
>()));
1143 inst
->setMetadata(kind
, node
);
1146 populateScopeMetadata(op
.getAliasScopesOrNull(),
1147 llvm::LLVMContext::MD_alias_scope
);
1148 populateScopeMetadata(op
.getNoAliasScopesOrNull(),
1149 llvm::LLVMContext::MD_noalias
);
1152 llvm::MDNode
*ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr
) const {
1153 return tbaaMetadataMapping
.lookup(tbaaAttr
);
1156 void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op
,
1157 llvm::Instruction
*inst
) {
1158 ArrayAttr tagRefs
= op
.getTBAATagsOrNull();
1159 if (!tagRefs
|| tagRefs
.empty())
1162 // LLVM IR currently does not support attaching more than one TBAA access tag
1163 // to a memory accessing instruction. It may be useful to support this in
1164 // future, but for the time being just ignore the metadata if MLIR operation
1165 // has multiple access tags.
1166 if (tagRefs
.size() > 1) {
1167 op
.emitWarning() << "TBAA access tags were not translated, because LLVM "
1168 "IR only supports a single tag per instruction";
1172 llvm::MDNode
*node
= getTBAANode(cast
<TBAATagAttr
>(tagRefs
[0]));
1173 inst
->setMetadata(llvm::LLVMContext::MD_tbaa
, node
);
1176 void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op
) {
1177 DenseI32ArrayAttr weightsAttr
= op
.getBranchWeightsOrNull();
1181 llvm::Instruction
*inst
= isa
<CallOp
>(op
) ? lookupCall(op
) : lookupBranch(op
);
1182 assert(inst
&& "expected the operation to have a mapping to an instruction");
1183 SmallVector
<uint32_t> weights(weightsAttr
.asArrayRef());
1185 llvm::LLVMContext::MD_prof
,
1186 llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights
));
1189 LogicalResult
ModuleTranslation::createTBAAMetadata() {
1190 llvm::LLVMContext
&ctx
= llvmModule
->getContext();
1191 llvm::IntegerType
*offsetTy
= llvm::IntegerType::get(ctx
, 64);
1193 // Walk the entire module and create all metadata nodes for the TBAA
1194 // attributes. The code below relies on two invariants of the
1195 // `AttrTypeWalker`:
1196 // 1. Attributes are visited in post-order: Since the attributes create a DAG,
1197 // this ensures that any lookups into `tbaaMetadataMapping` for child
1198 // attributes succeed.
1199 // 2. Attributes are only ever visited once: This way we don't leak any
1200 // LLVM metadata instances.
1201 AttrTypeWalker walker
;
1202 walker
.addWalk([&](TBAARootAttr root
) {
1203 tbaaMetadataMapping
.insert(
1204 {root
, llvm::MDNode::get(ctx
, llvm::MDString::get(ctx
, root
.getId()))});
1207 walker
.addWalk([&](TBAATypeDescriptorAttr descriptor
) {
1208 SmallVector
<llvm::Metadata
*> operands
;
1209 operands
.push_back(llvm::MDString::get(ctx
, descriptor
.getId()));
1210 for (TBAAMemberAttr member
: descriptor
.getMembers()) {
1211 operands
.push_back(tbaaMetadataMapping
.lookup(member
.getTypeDesc()));
1212 operands
.push_back(llvm::ConstantAsMetadata::get(
1213 llvm::ConstantInt::get(offsetTy
, member
.getOffset())));
1216 tbaaMetadataMapping
.insert({descriptor
, llvm::MDNode::get(ctx
, operands
)});
1219 walker
.addWalk([&](TBAATagAttr tag
) {
1220 SmallVector
<llvm::Metadata
*> operands
;
1222 operands
.push_back(tbaaMetadataMapping
.lookup(tag
.getBaseType()));
1223 operands
.push_back(tbaaMetadataMapping
.lookup(tag
.getAccessType()));
1225 operands
.push_back(llvm::ConstantAsMetadata::get(
1226 llvm::ConstantInt::get(offsetTy
, tag
.getOffset())));
1227 if (tag
.getConstant())
1229 llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(offsetTy
, 1)));
1231 tbaaMetadataMapping
.insert({tag
, llvm::MDNode::get(ctx
, operands
)});
1234 mlirModule
->walk([&](AliasAnalysisOpInterface analysisOpInterface
) {
1235 if (auto attr
= analysisOpInterface
.getTBAATagsOrNull())
1242 void ModuleTranslation::setLoopMetadata(Operation
*op
,
1243 llvm::Instruction
*inst
) {
1244 LoopAnnotationAttr attr
=
1245 TypeSwitch
<Operation
*, LoopAnnotationAttr
>(op
)
1246 .Case
<LLVM::BrOp
, LLVM::CondBrOp
>(
1247 [](auto branchOp
) { return branchOp
.getLoopAnnotationAttr(); });
1250 llvm::MDNode
*loopMD
=
1251 loopAnnotationTranslation
->translateLoopAnnotation(attr
, op
);
1252 inst
->setMetadata(llvm::LLVMContext::MD_loop
, loopMD
);
1255 llvm::Type
*ModuleTranslation::convertType(Type type
) {
1256 return typeTranslator
.translateType(type
);
1259 /// A helper to look up remapped operands in the value remapping table.
1260 SmallVector
<llvm::Value
*> ModuleTranslation::lookupValues(ValueRange values
) {
1261 SmallVector
<llvm::Value
*> remapped
;
1262 remapped
.reserve(values
.size());
1263 for (Value v
: values
)
1264 remapped
.push_back(lookupValue(v
));
1268 llvm::OpenMPIRBuilder
*ModuleTranslation::getOpenMPBuilder() {
1270 ompBuilder
= std::make_unique
<llvm::OpenMPIRBuilder
>(*llvmModule
);
1271 ompBuilder
->initialize();
1273 // Flags represented as top-level OpenMP dialect attributes are set in
1274 // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set
1275 // the default configuration.
1276 ompBuilder
->setConfig(llvm::OpenMPIRBuilderConfig(
1277 /* IsTargetDevice = */ false, /* IsGPU = */ false,
1278 /* OpenMPOffloadMandatory = */ false,
1279 /* HasRequiresReverseOffload = */ false,
1280 /* HasRequiresUnifiedAddress = */ false,
1281 /* HasRequiresUnifiedSharedMemory = */ false,
1282 /* HasRequiresDynamicAllocators = */ false));
1284 return ompBuilder
.get();
1287 llvm::DILocation
*ModuleTranslation::translateLoc(Location loc
,
1288 llvm::DILocalScope
*scope
) {
1289 return debugTranslation
->translateLoc(loc
, scope
);
1292 llvm::Metadata
*ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr
) {
1293 return debugTranslation
->translate(attr
);
1297 ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name
) {
1298 return llvmModule
->getOrInsertNamedMetadata(name
);
1301 void ModuleTranslation::StackFrame::anchor() {}
1303 static std::unique_ptr
<llvm::Module
>
1304 prepareLLVMModule(Operation
*m
, llvm::LLVMContext
&llvmContext
,
1306 m
->getContext()->getOrLoadDialect
<LLVM::LLVMDialect
>();
1307 auto llvmModule
= std::make_unique
<llvm::Module
>(name
, llvmContext
);
1308 if (auto dataLayoutAttr
=
1309 m
->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
1310 llvmModule
->setDataLayout(cast
<StringAttr
>(dataLayoutAttr
).getValue());
1312 FailureOr
<llvm::DataLayout
> llvmDataLayout(llvm::DataLayout(""));
1313 if (auto iface
= dyn_cast
<DataLayoutOpInterface
>(m
)) {
1314 if (DataLayoutSpecInterface spec
= iface
.getDataLayoutSpec()) {
1316 translateDataLayout(spec
, DataLayout(iface
), m
->getLoc());
1318 } else if (auto mod
= dyn_cast
<ModuleOp
>(m
)) {
1319 if (DataLayoutSpecInterface spec
= mod
.getDataLayoutSpec()) {
1321 translateDataLayout(spec
, DataLayout(mod
), m
->getLoc());
1324 if (failed(llvmDataLayout
))
1326 llvmModule
->setDataLayout(*llvmDataLayout
);
1328 if (auto targetTripleAttr
=
1329 m
->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
1330 llvmModule
->setTargetTriple(cast
<StringAttr
>(targetTripleAttr
).getValue());
1332 // Inject declarations for `malloc` and `free` functions that can be used in
1333 // memref allocation/deallocation coming from standard ops lowering.
1334 llvm::IRBuilder
<> builder(llvmContext
);
1335 llvmModule
->getOrInsertFunction("malloc", builder
.getInt8PtrTy(),
1336 builder
.getInt64Ty());
1337 llvmModule
->getOrInsertFunction("free", builder
.getVoidTy(),
1338 builder
.getInt8PtrTy());
1343 std::unique_ptr
<llvm::Module
>
1344 mlir::translateModuleToLLVMIR(Operation
*module
, llvm::LLVMContext
&llvmContext
,
1346 if (!satisfiesLLVMModule(module
)) {
1347 module
->emitOpError("can not be translated to an LLVMIR module");
1351 std::unique_ptr
<llvm::Module
> llvmModule
=
1352 prepareLLVMModule(module
, llvmContext
, name
);
1356 LLVM::ensureDistinctSuccessors(module
);
1358 ModuleTranslation
translator(module
, std::move(llvmModule
));
1359 llvm::IRBuilder
<> llvmBuilder(llvmContext
);
1361 // Convert module before functions and operations inside, so dialect
1362 // attributes can be used to change dialect-specific global configurations via
1363 // `amendOperation()`. These configurations can then influence the translation
1364 // of operations afterwards.
1365 if (failed(translator
.convertOperation(*module
, llvmBuilder
)))
1368 if (failed(translator
.convertComdats()))
1370 if (failed(translator
.convertFunctionSignatures()))
1372 if (failed(translator
.convertGlobals()))
1374 if (failed(translator
.createTBAAMetadata()))
1377 // Convert other top-level operations if possible.
1378 for (Operation
&o
: getModuleBody(module
).getOperations()) {
1379 if (!isa
<LLVM::LLVMFuncOp
, LLVM::GlobalOp
, LLVM::GlobalCtorsOp
,
1380 LLVM::GlobalDtorsOp
, LLVM::ComdatOp
>(&o
) &&
1381 !o
.hasTrait
<OpTrait::IsTerminator
>() &&
1382 failed(translator
.convertOperation(o
, llvmBuilder
))) {
1387 // Operations in function bodies with symbolic references must be converted
1388 // after the top-level operations they refer to are declared, so we do it
1390 if (failed(translator
.convertFunctions()))
1393 if (llvm::verifyModule(*translator
.llvmModule
, &llvm::errs()))
1396 return std::move(translator
.llvmModule
);