1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
16 #include "AttrKindDetail.h"
17 #include "DebugTranslation.h"
18 #include "LoopAnnotationTranslation.h"
19 #include "mlir/Analysis/TopologicalSortUtils.h"
20 #include "mlir/Dialect/DLTI/DLTI.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
23 #include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
24 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
25 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
26 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
27 #include "mlir/IR/AttrTypeSubElements.h"
28 #include "mlir/IR/Attributes.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/DialectResourceBlobManager.h"
32 #include "mlir/IR/RegionGraphTraits.h"
33 #include "mlir/Support/LLVM.h"
34 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
35 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
37 #include "llvm/ADT/PostOrderIterator.h"
38 #include "llvm/ADT/SetVector.h"
39 #include "llvm/ADT/StringExtras.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
42 #include "llvm/IR/BasicBlock.h"
43 #include "llvm/IR/CFG.h"
44 #include "llvm/IR/Constants.h"
45 #include "llvm/IR/DerivedTypes.h"
46 #include "llvm/IR/IRBuilder.h"
47 #include "llvm/IR/InlineAsm.h"
48 #include "llvm/IR/IntrinsicsNVPTX.h"
49 #include "llvm/IR/LLVMContext.h"
50 #include "llvm/IR/MDBuilder.h"
51 #include "llvm/IR/Module.h"
52 #include "llvm/IR/Verifier.h"
53 #include "llvm/Support/Debug.h"
54 #include "llvm/Support/raw_ostream.h"
55 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
56 #include "llvm/Transforms/Utils/Cloning.h"
57 #include "llvm/Transforms/Utils/ModuleUtils.h"
61 #define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
64 using namespace mlir::LLVM
;
65 using namespace mlir::LLVM::detail
;
67 extern llvm::cl::opt
<bool> UseNewDbgInfoFormat
;
69 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
72 /// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
73 /// instructions that are created for future reference.
75 /// This is intended to be used with the `CollectionScope` RAII object:
77 /// llvm::IRBuilder<..., InstructionCapturingInserter> builder;
79 /// InstructionCapturingInserter::CollectionScope scope(builder);
80 /// // Call IRBuilder methods as usual.
82 /// // This will return a list of all instructions created by the builder,
83 /// // in order of creation.
84 /// builder.getInserter().getCapturedInstructions();
86 /// // This will return an empty list.
87 /// builder.getInserter().getCapturedInstructions();
89 /// The capturing functionality is _disabled_ by default for performance
90 /// consideration. It needs to be explicitly enabled, which is achieved by
91 /// creating a `CollectionScope`.
92 class InstructionCapturingInserter
: public llvm::IRBuilderCallbackInserter
{
94 /// Constructs the inserter.
95 InstructionCapturingInserter()
96 : llvm::IRBuilderCallbackInserter([this](llvm::Instruction
*instruction
) {
97 if (LLVM_LIKELY(enabled
))
98 capturedInstructions
.push_back(instruction
);
101 /// Returns the list of LLVM IR instructions captured since the last cleanup.
102 ArrayRef
<llvm::Instruction
*> getCapturedInstructions() const {
103 return capturedInstructions
;
106 /// Clears the list of captured LLVM IR instructions.
107 void clearCapturedInstructions() { capturedInstructions
.clear(); }
109 /// RAII object enabling the capture of created LLVM IR instructions.
110 class CollectionScope
{
112 /// Creates the scope for the given inserter.
113 CollectionScope(llvm::IRBuilderBase
&irBuilder
, bool isBuilderCapturing
);
118 ArrayRef
<llvm::Instruction
*> getCapturedInstructions() {
121 return inserter
->getCapturedInstructions();
125 /// Back reference to the inserter.
126 InstructionCapturingInserter
*inserter
= nullptr;
128 /// List of instructions in the inserter prior to this scope.
129 SmallVector
<llvm::Instruction
*> previouslyCollectedInstructions
;
131 /// Whether the inserter was enabled prior to this scope.
135 /// Enable or disable the capturing mechanism.
136 void setEnabled(bool enabled
= true) { this->enabled
= enabled
; }
139 /// List of captured instructions.
140 SmallVector
<llvm::Instruction
*> capturedInstructions
;
142 /// Whether the collection is enabled.
143 bool enabled
= false;
146 using CapturingIRBuilder
=
147 llvm::IRBuilder
<llvm::ConstantFolder
, InstructionCapturingInserter
>;
150 InstructionCapturingInserter::CollectionScope::CollectionScope(
151 llvm::IRBuilderBase
&irBuilder
, bool isBuilderCapturing
) {
153 if (!isBuilderCapturing
)
156 auto &capturingIRBuilder
= static_cast<CapturingIRBuilder
&>(irBuilder
);
157 inserter
= &capturingIRBuilder
.getInserter();
158 wasEnabled
= inserter
->enabled
;
160 previouslyCollectedInstructions
.swap(inserter
->capturedInstructions
);
161 inserter
->setEnabled(true);
164 InstructionCapturingInserter::CollectionScope::~CollectionScope() {
168 previouslyCollectedInstructions
.swap(inserter
->capturedInstructions
);
169 // If collection was enabled (likely in another, surrounding scope), keep
170 // the instructions collected in this scope.
172 llvm::append_range(inserter
->capturedInstructions
,
173 previouslyCollectedInstructions
);
175 inserter
->setEnabled(wasEnabled
);
178 /// Translates the given data layout spec attribute to the LLVM IR data layout.
179 /// Only integer, float, pointer and endianness entries are currently supported.
180 static FailureOr
<llvm::DataLayout
>
181 translateDataLayout(DataLayoutSpecInterface attribute
,
182 const DataLayout
&dataLayout
,
183 std::optional
<Location
> loc
= std::nullopt
) {
185 loc
= UnknownLoc::get(attribute
.getContext());
187 // Translate the endianness attribute.
188 std::string llvmDataLayout
;
189 llvm::raw_string_ostream
layoutStream(llvmDataLayout
);
190 for (DataLayoutEntryInterface entry
: attribute
.getEntries()) {
191 auto key
= llvm::dyn_cast_if_present
<StringAttr
>(entry
.getKey());
194 if (key
.getValue() == DLTIDialect::kDataLayoutEndiannessKey
) {
195 auto value
= cast
<StringAttr
>(entry
.getValue());
196 bool isLittleEndian
=
197 value
.getValue() == DLTIDialect::kDataLayoutEndiannessLittle
;
198 layoutStream
<< "-" << (isLittleEndian
? "e" : "E");
201 if (key
.getValue() == DLTIDialect::kDataLayoutProgramMemorySpaceKey
) {
202 auto value
= cast
<IntegerAttr
>(entry
.getValue());
203 uint64_t space
= value
.getValue().getZExtValue();
204 // Skip the default address space.
207 layoutStream
<< "-P" << space
;
210 if (key
.getValue() == DLTIDialect::kDataLayoutGlobalMemorySpaceKey
) {
211 auto value
= cast
<IntegerAttr
>(entry
.getValue());
212 uint64_t space
= value
.getValue().getZExtValue();
213 // Skip the default address space.
216 layoutStream
<< "-G" << space
;
219 if (key
.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey
) {
220 auto value
= cast
<IntegerAttr
>(entry
.getValue());
221 uint64_t space
= value
.getValue().getZExtValue();
222 // Skip the default address space.
225 layoutStream
<< "-A" << space
;
228 if (key
.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey
) {
229 auto value
= cast
<IntegerAttr
>(entry
.getValue());
230 uint64_t alignment
= value
.getValue().getZExtValue();
231 // Skip the default stack alignment.
234 layoutStream
<< "-S" << alignment
;
237 emitError(*loc
) << "unsupported data layout key " << key
;
241 // Go through the list of entries to check which types are explicitly
242 // specified in entries. Where possible, data layout queries are used instead
243 // of directly inspecting the entries.
244 for (DataLayoutEntryInterface entry
: attribute
.getEntries()) {
245 auto type
= llvm::dyn_cast_if_present
<Type
>(entry
.getKey());
248 // Data layout for the index type is irrelevant at this point.
249 if (isa
<IndexType
>(type
))
252 LogicalResult result
=
253 llvm::TypeSwitch
<Type
, LogicalResult
>(type
)
254 .Case
<IntegerType
, Float16Type
, Float32Type
, Float64Type
,
255 Float80Type
, Float128Type
>([&](Type type
) -> LogicalResult
{
256 if (auto intType
= dyn_cast
<IntegerType
>(type
)) {
257 if (intType
.getSignedness() != IntegerType::Signless
)
258 return emitError(*loc
)
259 << "unsupported data layout for non-signless integer "
265 uint64_t size
= dataLayout
.getTypeSizeInBits(type
);
266 uint64_t abi
= dataLayout
.getTypeABIAlignment(type
) * 8u;
268 dataLayout
.getTypePreferredAlignment(type
) * 8u;
269 layoutStream
<< size
<< ":" << abi
;
270 if (abi
!= preferred
)
271 layoutStream
<< ":" << preferred
;
274 .Case([&](LLVMPointerType type
) {
275 layoutStream
<< "p" << type
.getAddressSpace() << ":";
276 uint64_t size
= dataLayout
.getTypeSizeInBits(type
);
277 uint64_t abi
= dataLayout
.getTypeABIAlignment(type
) * 8u;
279 dataLayout
.getTypePreferredAlignment(type
) * 8u;
280 uint64_t index
= *dataLayout
.getTypeIndexBitwidth(type
);
281 layoutStream
<< size
<< ":" << abi
<< ":" << preferred
<< ":"
285 .Default([loc
](Type type
) {
286 return emitError(*loc
)
287 << "unsupported type in data layout: " << type
;
292 StringRef
layoutSpec(llvmDataLayout
);
293 if (layoutSpec
.starts_with("-"))
294 layoutSpec
= layoutSpec
.drop_front();
296 return llvm::DataLayout(layoutSpec
);
299 /// Builds a constant of a sequential LLVM type `type`, potentially containing
300 /// other sequential types recursively, from the individual constant values
301 /// provided in `constants`. `shape` contains the number of elements in nested
302 /// sequential types. Reports errors at `loc` and returns nullptr on error.
303 static llvm::Constant
*
304 buildSequentialConstant(ArrayRef
<llvm::Constant
*> &constants
,
305 ArrayRef
<int64_t> shape
, llvm::Type
*type
,
308 llvm::Constant
*result
= constants
.front();
309 constants
= constants
.drop_front();
313 llvm::Type
*elementType
;
314 if (auto *arrayTy
= dyn_cast
<llvm::ArrayType
>(type
)) {
315 elementType
= arrayTy
->getElementType();
316 } else if (auto *vectorTy
= dyn_cast
<llvm::VectorType
>(type
)) {
317 elementType
= vectorTy
->getElementType();
319 emitError(loc
) << "expected sequential LLVM types wrapping a scalar";
323 SmallVector
<llvm::Constant
*, 8> nested
;
324 nested
.reserve(shape
.front());
325 for (int64_t i
= 0; i
< shape
.front(); ++i
) {
326 nested
.push_back(buildSequentialConstant(constants
, shape
.drop_front(),
332 if (shape
.size() == 1 && type
->isVectorTy())
333 return llvm::ConstantVector::get(nested
);
334 return llvm::ConstantArray::get(
335 llvm::ArrayType::get(elementType
, shape
.front()), nested
);
338 /// Returns the first non-sequential type nested in sequential types.
339 static llvm::Type
*getInnermostElementType(llvm::Type
*type
) {
341 if (auto *arrayTy
= dyn_cast
<llvm::ArrayType
>(type
)) {
342 type
= arrayTy
->getElementType();
343 } else if (auto *vectorTy
= dyn_cast
<llvm::VectorType
>(type
)) {
344 type
= vectorTy
->getElementType();
351 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
352 /// storage if possible. This supports elements attributes of tensor or vector
353 /// type and avoids constructing separate objects for individual values of the
354 /// innermost dimension. Constants for other dimensions are still constructed
355 /// recursively. Returns null if constructing from raw data is not supported for
356 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports
357 /// other errors at `loc`.
358 static llvm::Constant
*
359 convertDenseElementsAttr(Location loc
, DenseElementsAttr denseElementsAttr
,
360 llvm::Type
*llvmType
,
361 const ModuleTranslation
&moduleTranslation
) {
362 if (!denseElementsAttr
)
365 llvm::Type
*innermostLLVMType
= getInnermostElementType(llvmType
);
366 if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType
))
369 ShapedType type
= denseElementsAttr
.getType();
370 if (type
.getNumElements() == 0)
373 // Check that the raw data size matches what is expected for the scalar size.
374 // TODO: in theory, we could repack the data here to keep constructing from
376 // TODO: we may also need to consider endianness when cross-compiling to an
377 // architecture where it is different.
378 int64_t elementByteSize
= denseElementsAttr
.getRawData().size() /
379 denseElementsAttr
.getNumElements();
380 if (8 * elementByteSize
!= innermostLLVMType
->getScalarSizeInBits())
383 // Compute the shape of all dimensions but the innermost. Note that the
384 // innermost dimension may be that of the vector element type.
385 bool hasVectorElementType
= isa
<VectorType
>(type
.getElementType());
386 int64_t numAggregates
=
387 denseElementsAttr
.getNumElements() /
388 (hasVectorElementType
? 1
389 : denseElementsAttr
.getType().getShape().back());
390 ArrayRef
<int64_t> outerShape
= type
.getShape();
391 if (!hasVectorElementType
)
392 outerShape
= outerShape
.drop_back();
394 // Handle the case of vector splat, LLVM has special support for it.
395 if (denseElementsAttr
.isSplat() &&
396 (isa
<VectorType
>(type
) || hasVectorElementType
)) {
397 llvm::Constant
*splatValue
= LLVM::detail::getLLVMConstant(
398 innermostLLVMType
, denseElementsAttr
.getSplatValue
<Attribute
>(), loc
,
400 llvm::Constant
*splatVector
=
401 llvm::ConstantDataVector::getSplat(0, splatValue
);
402 SmallVector
<llvm::Constant
*> constants(numAggregates
, splatVector
);
403 ArrayRef
<llvm::Constant
*> constantsRef
= constants
;
404 return buildSequentialConstant(constantsRef
, outerShape
, llvmType
, loc
);
406 if (denseElementsAttr
.isSplat())
409 // In case of non-splat, create a constructor for the innermost constant from
410 // a piece of raw data.
411 std::function
<llvm::Constant
*(StringRef
)> buildCstData
;
412 if (isa
<TensorType
>(type
)) {
413 auto vectorElementType
= dyn_cast
<VectorType
>(type
.getElementType());
414 if (vectorElementType
&& vectorElementType
.getRank() == 1) {
415 buildCstData
= [&](StringRef data
) {
416 return llvm::ConstantDataVector::getRaw(
417 data
, vectorElementType
.getShape().back(), innermostLLVMType
);
419 } else if (!vectorElementType
) {
420 buildCstData
= [&](StringRef data
) {
421 return llvm::ConstantDataArray::getRaw(data
, type
.getShape().back(),
425 } else if (isa
<VectorType
>(type
)) {
426 buildCstData
= [&](StringRef data
) {
427 return llvm::ConstantDataVector::getRaw(data
, type
.getShape().back(),
434 // Create innermost constants and defer to the default constant creation
435 // mechanism for other dimensions.
436 SmallVector
<llvm::Constant
*> constants
;
437 int64_t aggregateSize
= denseElementsAttr
.getType().getShape().back() *
438 (innermostLLVMType
->getScalarSizeInBits() / 8);
439 constants
.reserve(numAggregates
);
440 for (unsigned i
= 0; i
< numAggregates
; ++i
) {
441 StringRef
data(denseElementsAttr
.getRawData().data() + i
* aggregateSize
,
443 constants
.push_back(buildCstData(data
));
446 ArrayRef
<llvm::Constant
*> constantsRef
= constants
;
447 return buildSequentialConstant(constantsRef
, outerShape
, llvmType
, loc
);
450 /// Convert a dense resource elements attribute to an LLVM IR constant using its
451 /// raw data storage if possible. This supports elements attributes of tensor or
452 /// vector type and avoids constructing separate objects for individual values
453 /// of the innermost dimension. Constants for other dimensions are still
454 /// constructed recursively. Returns nullptr on failure and emits errors at
456 static llvm::Constant
*convertDenseResourceElementsAttr(
457 Location loc
, DenseResourceElementsAttr denseResourceAttr
,
458 llvm::Type
*llvmType
, const ModuleTranslation
&moduleTranslation
) {
459 assert(denseResourceAttr
&& "expected non-null attribute");
461 llvm::Type
*innermostLLVMType
= getInnermostElementType(llvmType
);
462 if (!llvm::ConstantDataSequential::isElementTypeCompatible(
463 innermostLLVMType
)) {
464 emitError(loc
, "no known conversion for innermost element type");
468 ShapedType type
= denseResourceAttr
.getType();
469 assert(type
.getNumElements() > 0 && "Expected non-empty elements attribute");
471 AsmResourceBlob
*blob
= denseResourceAttr
.getRawHandle().getBlob();
473 emitError(loc
, "resource does not exist");
477 ArrayRef
<char> rawData
= blob
->getData();
479 // Check that the raw data size matches what is expected for the scalar size.
480 // TODO: in theory, we could repack the data here to keep constructing from
482 // TODO: we may also need to consider endianness when cross-compiling to an
483 // architecture where it is different.
484 int64_t numElements
= denseResourceAttr
.getType().getNumElements();
485 int64_t elementByteSize
= rawData
.size() / numElements
;
486 if (8 * elementByteSize
!= innermostLLVMType
->getScalarSizeInBits()) {
487 emitError(loc
, "raw data size does not match element type size");
491 // Compute the shape of all dimensions but the innermost. Note that the
492 // innermost dimension may be that of the vector element type.
493 bool hasVectorElementType
= isa
<VectorType
>(type
.getElementType());
494 int64_t numAggregates
=
495 numElements
/ (hasVectorElementType
497 : denseResourceAttr
.getType().getShape().back());
498 ArrayRef
<int64_t> outerShape
= type
.getShape();
499 if (!hasVectorElementType
)
500 outerShape
= outerShape
.drop_back();
502 // Create a constructor for the innermost constant from a piece of raw data.
503 std::function
<llvm::Constant
*(StringRef
)> buildCstData
;
504 if (isa
<TensorType
>(type
)) {
505 auto vectorElementType
= dyn_cast
<VectorType
>(type
.getElementType());
506 if (vectorElementType
&& vectorElementType
.getRank() == 1) {
507 buildCstData
= [&](StringRef data
) {
508 return llvm::ConstantDataVector::getRaw(
509 data
, vectorElementType
.getShape().back(), innermostLLVMType
);
511 } else if (!vectorElementType
) {
512 buildCstData
= [&](StringRef data
) {
513 return llvm::ConstantDataArray::getRaw(data
, type
.getShape().back(),
517 } else if (isa
<VectorType
>(type
)) {
518 buildCstData
= [&](StringRef data
) {
519 return llvm::ConstantDataVector::getRaw(data
, type
.getShape().back(),
524 emitError(loc
, "unsupported dense_resource type");
528 // Create innermost constants and defer to the default constant creation
529 // mechanism for other dimensions.
530 SmallVector
<llvm::Constant
*> constants
;
531 int64_t aggregateSize
= denseResourceAttr
.getType().getShape().back() *
532 (innermostLLVMType
->getScalarSizeInBits() / 8);
533 constants
.reserve(numAggregates
);
534 for (unsigned i
= 0; i
< numAggregates
; ++i
) {
535 StringRef
data(rawData
.data() + i
* aggregateSize
, aggregateSize
);
536 constants
.push_back(buildCstData(data
));
539 ArrayRef
<llvm::Constant
*> constantsRef
= constants
;
540 return buildSequentialConstant(constantsRef
, outerShape
, llvmType
, loc
);
543 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
544 /// This currently supports integer, floating point, splat and dense element
545 /// attributes and combinations thereof. Also, an array attribute with two
546 /// elements is supported to represent a complex constant. In case of error,
547 /// report it to `loc` and return nullptr.
548 llvm::Constant
*mlir::LLVM::detail::getLLVMConstant(
549 llvm::Type
*llvmType
, Attribute attr
, Location loc
,
550 const ModuleTranslation
&moduleTranslation
) {
552 return llvm::UndefValue::get(llvmType
);
553 if (auto *structType
= dyn_cast
<::llvm::StructType
>(llvmType
)) {
554 auto arrayAttr
= dyn_cast
<ArrayAttr
>(attr
);
556 emitError(loc
, "expected an array attribute for a struct constant");
559 SmallVector
<llvm::Constant
*> structElements
;
560 structElements
.reserve(structType
->getNumElements());
561 for (auto [elemType
, elemAttr
] :
562 zip_equal(structType
->elements(), arrayAttr
)) {
563 llvm::Constant
*element
=
564 getLLVMConstant(elemType
, elemAttr
, loc
, moduleTranslation
);
567 structElements
.push_back(element
);
569 return llvm::ConstantStruct::get(structType
, structElements
);
571 // For integer types, we allow a mismatch in sizes as the index type in
572 // MLIR might have a different size than the index type in the LLVM module.
573 if (auto intAttr
= dyn_cast
<IntegerAttr
>(attr
))
574 return llvm::ConstantInt::get(
576 intAttr
.getValue().sextOrTrunc(llvmType
->getIntegerBitWidth()));
577 if (auto floatAttr
= dyn_cast
<FloatAttr
>(attr
)) {
578 const llvm::fltSemantics
&sem
= floatAttr
.getValue().getSemantics();
579 // Special case for 8-bit floats, which are represented by integers due to
580 // the lack of native fp8 types in LLVM at the moment. Additionally, handle
581 // targets (like AMDGPU) that don't implement bfloat and convert all bfloats
583 unsigned floatWidth
= APFloat::getSizeInBits(sem
);
584 if (llvmType
->isIntegerTy(floatWidth
))
585 return llvm::ConstantInt::get(llvmType
,
586 floatAttr
.getValue().bitcastToAPInt());
588 llvm::Type::getFloatingPointTy(llvmType
->getContext(),
589 floatAttr
.getValue().getSemantics())) {
590 emitError(loc
, "FloatAttr does not match expected type of the constant");
593 return llvm::ConstantFP::get(llvmType
, floatAttr
.getValue());
595 if (auto funcAttr
= dyn_cast
<FlatSymbolRefAttr
>(attr
))
596 return llvm::ConstantExpr::getBitCast(
597 moduleTranslation
.lookupFunction(funcAttr
.getValue()), llvmType
);
598 if (auto splatAttr
= dyn_cast
<SplatElementsAttr
>(attr
)) {
599 llvm::Type
*elementType
;
600 uint64_t numElements
;
601 bool isScalable
= false;
602 if (auto *arrayTy
= dyn_cast
<llvm::ArrayType
>(llvmType
)) {
603 elementType
= arrayTy
->getElementType();
604 numElements
= arrayTy
->getNumElements();
605 } else if (auto *fVectorTy
= dyn_cast
<llvm::FixedVectorType
>(llvmType
)) {
606 elementType
= fVectorTy
->getElementType();
607 numElements
= fVectorTy
->getNumElements();
608 } else if (auto *sVectorTy
= dyn_cast
<llvm::ScalableVectorType
>(llvmType
)) {
609 elementType
= sVectorTy
->getElementType();
610 numElements
= sVectorTy
->getMinNumElements();
613 llvm_unreachable("unrecognized constant vector type");
615 // Splat value is a scalar. Extract it only if the element type is not
616 // another sequence type. The recursion terminates because each step removes
617 // one outer sequential type.
618 bool elementTypeSequential
=
619 isa
<llvm::ArrayType
, llvm::VectorType
>(elementType
);
620 llvm::Constant
*child
= getLLVMConstant(
622 elementTypeSequential
? splatAttr
623 : splatAttr
.getSplatValue
<Attribute
>(),
624 loc
, moduleTranslation
);
627 if (llvmType
->isVectorTy())
628 return llvm::ConstantVector::getSplat(
629 llvm::ElementCount::get(numElements
, /*Scalable=*/isScalable
), child
);
630 if (llvmType
->isArrayTy()) {
631 auto *arrayType
= llvm::ArrayType::get(elementType
, numElements
);
632 if (child
->isZeroValue()) {
633 return llvm::ConstantAggregateZero::get(arrayType
);
635 if (llvm::ConstantDataSequential::isElementTypeCompatible(
637 // TODO: Handle all compatible types. This code only handles integer.
638 if (isa
<llvm::IntegerType
>(elementType
)) {
639 if (llvm::ConstantInt
*ci
= dyn_cast
<llvm::ConstantInt
>(child
)) {
640 if (ci
->getBitWidth() == 8) {
641 SmallVector
<int8_t> constants(numElements
, ci
->getZExtValue());
642 return llvm::ConstantDataArray::get(elementType
->getContext(),
645 if (ci
->getBitWidth() == 16) {
646 SmallVector
<int16_t> constants(numElements
, ci
->getZExtValue());
647 return llvm::ConstantDataArray::get(elementType
->getContext(),
650 if (ci
->getBitWidth() == 32) {
651 SmallVector
<int32_t> constants(numElements
, ci
->getZExtValue());
652 return llvm::ConstantDataArray::get(elementType
->getContext(),
655 if (ci
->getBitWidth() == 64) {
656 SmallVector
<int64_t> constants(numElements
, ci
->getZExtValue());
657 return llvm::ConstantDataArray::get(elementType
->getContext(),
663 // std::vector is used here to accomodate large number of elements that
664 // exceed SmallVector capacity.
665 std::vector
<llvm::Constant
*> constants(numElements
, child
);
666 return llvm::ConstantArray::get(arrayType
, constants
);
671 // Try using raw elements data if possible.
672 if (llvm::Constant
*result
=
673 convertDenseElementsAttr(loc
, dyn_cast
<DenseElementsAttr
>(attr
),
674 llvmType
, moduleTranslation
)) {
678 if (auto denseResourceAttr
= dyn_cast
<DenseResourceElementsAttr
>(attr
)) {
679 return convertDenseResourceElementsAttr(loc
, denseResourceAttr
, llvmType
,
683 // Fall back to element-by-element construction otherwise.
684 if (auto elementsAttr
= dyn_cast
<ElementsAttr
>(attr
)) {
685 assert(elementsAttr
.getShapedType().hasStaticShape());
686 assert(!elementsAttr
.getShapedType().getShape().empty() &&
687 "unexpected empty elements attribute shape");
689 SmallVector
<llvm::Constant
*, 8> constants
;
690 constants
.reserve(elementsAttr
.getNumElements());
691 llvm::Type
*innermostType
= getInnermostElementType(llvmType
);
692 for (auto n
: elementsAttr
.getValues
<Attribute
>()) {
694 getLLVMConstant(innermostType
, n
, loc
, moduleTranslation
));
695 if (!constants
.back())
698 ArrayRef
<llvm::Constant
*> constantsRef
= constants
;
699 llvm::Constant
*result
= buildSequentialConstant(
700 constantsRef
, elementsAttr
.getShapedType().getShape(), llvmType
, loc
);
701 assert(constantsRef
.empty() && "did not consume all elemental constants");
705 if (auto stringAttr
= dyn_cast
<StringAttr
>(attr
)) {
706 return llvm::ConstantDataArray::get(
707 moduleTranslation
.getLLVMContext(),
708 ArrayRef
<char>{stringAttr
.getValue().data(),
709 stringAttr
.getValue().size()});
711 emitError(loc
, "unsupported constant value");
715 ModuleTranslation::ModuleTranslation(Operation
*module
,
716 std::unique_ptr
<llvm::Module
> llvmModule
)
717 : mlirModule(module
), llvmModule(std::move(llvmModule
)),
719 std::make_unique
<DebugTranslation
>(module
, *this->llvmModule
)),
720 loopAnnotationTranslation(std::make_unique
<LoopAnnotationTranslation
>(
721 *this, *this->llvmModule
)),
722 typeTranslator(this->llvmModule
->getContext()),
723 iface(module
->getContext()) {
724 assert(satisfiesLLVMModule(mlirModule
) &&
725 "mlirModule should honor LLVM's module semantics.");
728 ModuleTranslation::~ModuleTranslation() {
730 ompBuilder
->finalize();
733 void ModuleTranslation::forgetMapping(Region
®ion
) {
734 SmallVector
<Region
*> toProcess
;
735 toProcess
.push_back(®ion
);
736 while (!toProcess
.empty()) {
737 Region
*current
= toProcess
.pop_back_val();
738 for (Block
&block
: *current
) {
739 blockMapping
.erase(&block
);
740 for (Value arg
: block
.getArguments())
741 valueMapping
.erase(arg
);
742 for (Operation
&op
: block
) {
743 for (Value value
: op
.getResults())
744 valueMapping
.erase(value
);
745 if (op
.hasSuccessors())
746 branchMapping
.erase(&op
);
747 if (isa
<LLVM::GlobalOp
>(op
))
748 globalsMapping
.erase(&op
);
749 if (isa
<LLVM::CallOp
>(op
))
750 callMapping
.erase(&op
);
753 llvm::map_range(op
.getRegions(), [](Region
&r
) { return &r
; }));
759 /// Get the SSA value passed to the current block from the terminator operation
760 /// of its predecessor.
761 static Value
getPHISourceValue(Block
*current
, Block
*pred
,
762 unsigned numArguments
, unsigned index
) {
763 Operation
&terminator
= *pred
->getTerminator();
764 if (isa
<LLVM::BrOp
>(terminator
))
765 return terminator
.getOperand(index
);
768 llvm::SmallPtrSet
<Block
*, 4> seenSuccessors
;
769 for (unsigned i
= 0, e
= terminator
.getNumSuccessors(); i
< e
; ++i
) {
770 Block
*successor
= terminator
.getSuccessor(i
);
771 auto branch
= cast
<BranchOpInterface
>(terminator
);
772 SuccessorOperands successorOperands
= branch
.getSuccessorOperands(i
);
774 (!seenSuccessors
.contains(successor
) || successorOperands
.empty()) &&
775 "successors with arguments in LLVM branches must be different blocks");
776 seenSuccessors
.insert(successor
);
780 // For instructions that branch based on a condition value, we need to take
781 // the operands for the branch that was taken.
782 if (auto condBranchOp
= dyn_cast
<LLVM::CondBrOp
>(terminator
)) {
783 // For conditional branches, we take the operands from either the "true" or
784 // the "false" branch.
785 return condBranchOp
.getSuccessor(0) == current
786 ? condBranchOp
.getTrueDestOperands()[index
]
787 : condBranchOp
.getFalseDestOperands()[index
];
790 if (auto switchOp
= dyn_cast
<LLVM::SwitchOp
>(terminator
)) {
791 // For switches, we take the operands from either the default case, or from
792 // the case branch that was taken.
793 if (switchOp
.getDefaultDestination() == current
)
794 return switchOp
.getDefaultOperands()[index
];
795 for (const auto &i
: llvm::enumerate(switchOp
.getCaseDestinations()))
796 if (i
.value() == current
)
797 return switchOp
.getCaseOperands(i
.index())[index
];
800 if (auto invokeOp
= dyn_cast
<LLVM::InvokeOp
>(terminator
)) {
801 return invokeOp
.getNormalDest() == current
802 ? invokeOp
.getNormalDestOperands()[index
]
803 : invokeOp
.getUnwindDestOperands()[index
];
807 "only branch, switch or invoke operations can be terminators "
808 "of a block that has successors");
811 /// Connect the PHI nodes to the results of preceding blocks.
812 void mlir::LLVM::detail::connectPHINodes(Region
®ion
,
813 const ModuleTranslation
&state
) {
814 // Skip the first block, it cannot be branched to and its arguments correspond
815 // to the arguments of the LLVM function.
816 for (Block
&bb
: llvm::drop_begin(region
)) {
817 llvm::BasicBlock
*llvmBB
= state
.lookupBlock(&bb
);
818 auto phis
= llvmBB
->phis();
819 auto numArguments
= bb
.getNumArguments();
820 assert(numArguments
== std::distance(phis
.begin(), phis
.end()));
821 for (auto [index
, phiNode
] : llvm::enumerate(phis
)) {
822 for (auto *pred
: bb
.getPredecessors()) {
823 // Find the LLVM IR block that contains the converted terminator
824 // instruction and use it in the PHI node. Note that this block is not
825 // necessarily the same as state.lookupBlock(pred), some operations
826 // (in particular, OpenMP operations using OpenMPIRBuilder) may have
828 llvm::Instruction
*terminator
=
829 state
.lookupBranch(pred
->getTerminator());
830 assert(terminator
&& "missing the mapping for a terminator");
831 phiNode
.addIncoming(state
.lookupValue(getPHISourceValue(
832 &bb
, pred
, numArguments
, index
)),
833 terminator
->getParent());
839 llvm::CallInst
*mlir::LLVM::detail::createIntrinsicCall(
840 llvm::IRBuilderBase
&builder
, llvm::Intrinsic::ID intrinsic
,
841 ArrayRef
<llvm::Value
*> args
, ArrayRef
<llvm::Type
*> tys
) {
842 llvm::Module
*module
= builder
.GetInsertBlock()->getModule();
844 llvm::Intrinsic::getOrInsertDeclaration(module
, intrinsic
, tys
);
845 return builder
.CreateCall(fn
, args
);
848 llvm::CallInst
*mlir::LLVM::detail::createIntrinsicCall(
849 llvm::IRBuilderBase
&builder
, ModuleTranslation
&moduleTranslation
,
850 Operation
*intrOp
, llvm::Intrinsic::ID intrinsic
, unsigned numResults
,
851 ArrayRef
<unsigned> overloadedResults
, ArrayRef
<unsigned> overloadedOperands
,
852 ArrayRef
<unsigned> immArgPositions
,
853 ArrayRef
<StringLiteral
> immArgAttrNames
) {
854 assert(immArgPositions
.size() == immArgAttrNames
.size() &&
855 "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
858 SmallVector
<llvm::OperandBundleDef
> opBundles
;
859 size_t numOpBundleOperands
= 0;
860 auto opBundleSizesAttr
= cast_if_present
<DenseI32ArrayAttr
>(
861 intrOp
->getAttr(LLVMDialect::getOpBundleSizesAttrName()));
862 auto opBundleTagsAttr
= cast_if_present
<ArrayAttr
>(
863 intrOp
->getAttr(LLVMDialect::getOpBundleTagsAttrName()));
865 if (opBundleSizesAttr
&& opBundleTagsAttr
) {
866 ArrayRef
<int> opBundleSizes
= opBundleSizesAttr
.asArrayRef();
867 assert(opBundleSizes
.size() == opBundleTagsAttr
.size() &&
868 "operand bundles and tags do not match");
870 numOpBundleOperands
=
871 std::accumulate(opBundleSizes
.begin(), opBundleSizes
.end(), size_t(0));
872 assert(numOpBundleOperands
<= intrOp
->getNumOperands() &&
873 "operand bundle operands is more than the number of operands");
875 ValueRange operands
= intrOp
->getOperands().take_back(numOpBundleOperands
);
876 size_t nextOperandIdx
= 0;
877 opBundles
.reserve(opBundleSizesAttr
.size());
879 for (auto [opBundleTagAttr
, bundleSize
] :
880 llvm::zip(opBundleTagsAttr
, opBundleSizes
)) {
881 auto bundleTag
= cast
<StringAttr
>(opBundleTagAttr
).str();
882 auto bundleOperands
= moduleTranslation
.lookupValues(
883 operands
.slice(nextOperandIdx
, bundleSize
));
884 opBundles
.emplace_back(std::move(bundleTag
), std::move(bundleOperands
));
885 nextOperandIdx
+= bundleSize
;
889 // Map operands and attributes to LLVM values.
890 auto opOperands
= intrOp
->getOperands().drop_back(numOpBundleOperands
);
891 auto operands
= moduleTranslation
.lookupValues(opOperands
);
892 SmallVector
<llvm::Value
*> args(immArgPositions
.size() + operands
.size());
893 for (auto [immArgPos
, immArgName
] :
894 llvm::zip(immArgPositions
, immArgAttrNames
)) {
895 auto attr
= llvm::cast
<TypedAttr
>(intrOp
->getAttr(immArgName
));
896 assert(attr
.getType().isIntOrFloat() && "expected int or float immarg");
897 auto *type
= moduleTranslation
.convertType(attr
.getType());
898 args
[immArgPos
] = LLVM::detail::getLLVMConstant(
899 type
, attr
, intrOp
->getLoc(), moduleTranslation
);
902 for (auto &arg
: args
) {
904 arg
= operands
[opArg
++];
907 // Resolve overloaded intrinsic declaration.
908 SmallVector
<llvm::Type
*> overloadedTypes
;
909 for (unsigned overloadedResultIdx
: overloadedResults
) {
910 if (numResults
> 1) {
911 // More than one result is mapped to an LLVM struct.
912 overloadedTypes
.push_back(moduleTranslation
.convertType(
913 llvm::cast
<LLVM::LLVMStructType
>(intrOp
->getResult(0).getType())
914 .getBody()[overloadedResultIdx
]));
916 overloadedTypes
.push_back(
917 moduleTranslation
.convertType(intrOp
->getResult(0).getType()));
920 for (unsigned overloadedOperandIdx
: overloadedOperands
)
921 overloadedTypes
.push_back(args
[overloadedOperandIdx
]->getType());
922 llvm::Module
*module
= builder
.GetInsertBlock()->getModule();
923 llvm::Function
*llvmIntr
= llvm::Intrinsic::getOrInsertDeclaration(
924 module
, intrinsic
, overloadedTypes
);
926 return builder
.CreateCall(llvmIntr
, args
, opBundles
);
929 /// Given a single MLIR operation, create the corresponding LLVM IR operation
930 /// using the `builder`.
931 LogicalResult
ModuleTranslation::convertOperation(Operation
&op
,
932 llvm::IRBuilderBase
&builder
,
933 bool recordInsertions
) {
934 const LLVMTranslationDialectInterface
*opIface
= iface
.getInterfaceFor(&op
);
936 return op
.emitError("cannot be converted to LLVM IR: missing "
937 "`LLVMTranslationDialectInterface` registration for "
941 InstructionCapturingInserter::CollectionScope
scope(builder
,
943 if (failed(opIface
->convertOperation(&op
, builder
, *this)))
944 return op
.emitError("LLVM Translation failed for operation: ")
947 return convertDialectAttributes(&op
, scope
.getCapturedInstructions());
950 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
951 /// to define values corresponding to the MLIR block arguments. These nodes
952 /// are not connected to the source basic blocks, which may not exist yet. Uses
953 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
954 /// been created for `bb` and included in the block mapping. Inserts new
955 /// instructions at the end of the block and leaves `builder` in a state
956 /// suitable for further insertion into the end of the block.
957 LogicalResult
ModuleTranslation::convertBlockImpl(Block
&bb
,
958 bool ignoreArguments
,
959 llvm::IRBuilderBase
&builder
,
960 bool recordInsertions
) {
961 builder
.SetInsertPoint(lookupBlock(&bb
));
962 auto *subprogram
= builder
.GetInsertBlock()->getParent()->getSubprogram();
964 // Before traversing operations, make block arguments available through
965 // value remapping and PHI nodes, but do not add incoming edges for the PHI
966 // nodes just yet: those values may be defined by this or following blocks.
967 // This step is omitted if "ignoreArguments" is set. The arguments of the
968 // first block have been already made available through the remapping of
969 // LLVM function arguments.
970 if (!ignoreArguments
) {
971 auto predecessors
= bb
.getPredecessors();
972 unsigned numPredecessors
=
973 std::distance(predecessors
.begin(), predecessors
.end());
974 for (auto arg
: bb
.getArguments()) {
975 auto wrappedType
= arg
.getType();
976 if (!isCompatibleType(wrappedType
))
977 return emitError(bb
.front().getLoc(),
978 "block argument does not have an LLVM type");
979 builder
.SetCurrentDebugLocation(
980 debugTranslation
->translateLoc(arg
.getLoc(), subprogram
));
981 llvm::Type
*type
= convertType(wrappedType
);
982 llvm::PHINode
*phi
= builder
.CreatePHI(type
, numPredecessors
);
987 // Traverse operations.
988 for (auto &op
: bb
) {
989 // Set the current debug location within the builder.
990 builder
.SetCurrentDebugLocation(
991 debugTranslation
->translateLoc(op
.getLoc(), subprogram
));
993 if (failed(convertOperation(op
, builder
, recordInsertions
)))
996 // Set the branch weight metadata on the translated instruction.
997 if (auto iface
= dyn_cast
<BranchWeightOpInterface
>(op
))
998 setBranchWeightsMetadata(iface
);
1004 /// A helper method to get the single Block in an operation honoring LLVM's
1005 /// module requirements.
1006 static Block
&getModuleBody(Operation
*module
) {
1007 return module
->getRegion(0).front();
1010 /// A helper method to decide if a constant must not be set as a global variable
1011 /// initializer. For an external linkage variable, the variable with an
1012 /// initializer is considered externally visible and defined in this module, the
1013 /// variable without an initializer is externally available and is defined
1015 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage
,
1016 llvm::Constant
*cst
) {
1017 return (linkage
== llvm::GlobalVariable::ExternalLinkage
&& !cst
) ||
1018 linkage
== llvm::GlobalVariable::ExternalWeakLinkage
;
1021 /// Sets the runtime preemption specifier of `gv` to dso_local if
1022 /// `dsoLocalRequested` is true, otherwise it is left unchanged.
1023 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested
,
1024 llvm::GlobalValue
*gv
) {
1025 if (dsoLocalRequested
)
1026 gv
->setDSOLocal(true);
1029 /// Create named global variables that correspond to llvm.mlir.global
1030 /// definitions. Convert llvm.global_ctors and global_dtors ops.
1031 LogicalResult
ModuleTranslation::convertGlobals() {
1032 // Mapping from compile unit to its respective set of global variables.
1033 DenseMap
<llvm::DICompileUnit
*, SmallVector
<llvm::Metadata
*>> allGVars
;
1035 for (auto op
: getModuleBody(mlirModule
).getOps
<LLVM::GlobalOp
>()) {
1036 llvm::Type
*type
= convertType(op
.getType());
1037 llvm::Constant
*cst
= nullptr;
1038 if (op
.getValueOrNull()) {
1039 // String attributes are treated separately because they cannot appear as
1040 // in-function constants and are thus not supported by getLLVMConstant.
1041 if (auto strAttr
= dyn_cast_or_null
<StringAttr
>(op
.getValueOrNull())) {
1042 cst
= llvm::ConstantDataArray::getString(
1043 llvmModule
->getContext(), strAttr
.getValue(), /*AddNull=*/false);
1044 type
= cst
->getType();
1045 } else if (!(cst
= getLLVMConstant(type
, op
.getValueOrNull(), op
.getLoc(),
1051 auto linkage
= convertLinkageToLLVM(op
.getLinkage());
1053 // LLVM IR requires constant with linkage other than external or weak
1054 // external to have initializers. If MLIR does not provide an initializer,
1055 // default to undef.
1056 bool dropInitializer
= shouldDropGlobalInitializer(linkage
, cst
);
1057 if (!dropInitializer
&& !cst
)
1058 cst
= llvm::UndefValue::get(type
);
1059 else if (dropInitializer
&& cst
)
1062 auto *var
= new llvm::GlobalVariable(
1063 *llvmModule
, type
, op
.getConstant(), linkage
, cst
, op
.getSymName(),
1064 /*InsertBefore=*/nullptr,
1065 op
.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
1066 : llvm::GlobalValue::NotThreadLocal
,
1067 op
.getAddrSpace(), op
.getExternallyInitialized());
1069 if (std::optional
<mlir::SymbolRefAttr
> comdat
= op
.getComdat()) {
1070 auto selectorOp
= cast
<ComdatSelectorOp
>(
1071 SymbolTable::lookupNearestSymbolFrom(op
, *comdat
));
1072 var
->setComdat(comdatMapping
.lookup(selectorOp
));
1075 if (op
.getUnnamedAddr().has_value())
1076 var
->setUnnamedAddr(convertUnnamedAddrToLLVM(*op
.getUnnamedAddr()));
1078 if (op
.getSection().has_value())
1079 var
->setSection(*op
.getSection());
1081 addRuntimePreemptionSpecifier(op
.getDsoLocal(), var
);
1083 std::optional
<uint64_t> alignment
= op
.getAlignment();
1084 if (alignment
.has_value())
1085 var
->setAlignment(llvm::MaybeAlign(alignment
.value()));
1087 var
->setVisibility(convertVisibilityToLLVM(op
.getVisibility_()));
1089 globalsMapping
.try_emplace(op
, var
);
1091 // Add debug information if present.
1092 if (op
.getDbgExprs()) {
1093 for (auto exprAttr
:
1094 op
.getDbgExprs()->getAsRange
<DIGlobalVariableExpressionAttr
>()) {
1095 llvm::DIGlobalVariableExpression
*diGlobalExpr
=
1096 debugTranslation
->translateGlobalVariableExpression(exprAttr
);
1097 llvm::DIGlobalVariable
*diGlobalVar
= diGlobalExpr
->getVariable();
1098 var
->addDebugInfo(diGlobalExpr
);
1100 // There is no `globals` field in DICompileUnitAttr which can be
1101 // directly assigned to DICompileUnit. We have to build the list by
1102 // looking at the dbgExpr of all the GlobalOps. The scope of the
1103 // variable is used to get the DICompileUnit in which to add it. But
1104 // there are cases where the scope of a global does not directly point
1105 // to the DICompileUnit and we have to do a bit more work to get to
1106 // it. Some of those cases are:
1108 // 1. For the languages that support modules, the scope hierarchy can
1109 // be variable -> DIModule -> DICompileUnit
1111 // 2. For the Fortran common block variable, the scope hierarchy can
1112 // be variable -> DICommonBlock -> DISubprogram -> DICompileUnit
1114 // 3. For entities like static local variables in C or variable with
1115 // SAVE attribute in Fortran, the scope hierarchy can be
1116 // variable -> DISubprogram -> DICompileUnit
1117 llvm::DIScope
*scope
= diGlobalVar
->getScope();
1118 if (auto *mod
= dyn_cast_if_present
<llvm::DIModule
>(scope
))
1119 scope
= mod
->getScope();
1120 else if (auto *cb
= dyn_cast_if_present
<llvm::DICommonBlock
>(scope
)) {
1122 dyn_cast_if_present
<llvm::DISubprogram
>(cb
->getScope()))
1123 scope
= sp
->getUnit();
1124 } else if (auto *sp
= dyn_cast_if_present
<llvm::DISubprogram
>(scope
))
1125 scope
= sp
->getUnit();
1127 // Get the compile unit (scope) of the the global variable.
1128 if (llvm::DICompileUnit
*compileUnit
=
1129 dyn_cast_if_present
<llvm::DICompileUnit
>(scope
)) {
1130 // Update the compile unit with this incoming global variable
1131 // expression during the finalizing step later.
1132 allGVars
[compileUnit
].push_back(diGlobalExpr
);
1138 // Convert global variable bodies. This is done after all global variables
1139 // have been created in LLVM IR because a global body may refer to another
1140 // global or itself. So all global variables need to be mapped first.
1141 for (auto op
: getModuleBody(mlirModule
).getOps
<LLVM::GlobalOp
>()) {
1142 if (Block
*initializer
= op
.getInitializerBlock()) {
1143 llvm::IRBuilder
<> builder(llvmModule
->getContext());
1145 [[maybe_unused
]] int numConstantsHit
= 0;
1146 [[maybe_unused
]] int numConstantsErased
= 0;
1147 DenseMap
<llvm::ConstantAggregate
*, int> constantAggregateUseMap
;
1149 for (auto &op
: initializer
->without_terminator()) {
1150 if (failed(convertOperation(op
, builder
)))
1151 return emitError(op
.getLoc(), "fail to convert global initializer");
1152 auto *cst
= dyn_cast
<llvm::Constant
>(lookupValue(op
.getResult(0)));
1154 return emitError(op
.getLoc(), "unemittable constant value");
1156 // When emitting an LLVM constant, a new constant is created and the old
1157 // constant may become dangling and take space. We should remove the
1158 // dangling constants to avoid memory explosion especially for constant
1159 // arrays whose number of elements is large.
1160 // Because multiple operations may refer to the same constant, we need
1161 // to count the number of uses of each constant array and remove it only
1162 // when the count becomes zero.
1163 if (auto *agg
= dyn_cast
<llvm::ConstantAggregate
>(cst
)) {
1165 Value result
= op
.getResult(0);
1166 int numUsers
= std::distance(result
.use_begin(), result
.use_end());
1167 auto [iterator
, inserted
] =
1168 constantAggregateUseMap
.try_emplace(agg
, numUsers
);
1170 // Key already exists, update the value
1171 iterator
->second
+= numUsers
;
1174 // Scan the operands of the operation to decrement the use count of
1175 // constants. Erase the constant if the use count becomes zero.
1176 for (Value v
: op
.getOperands()) {
1177 auto cst
= dyn_cast
<llvm::ConstantAggregate
>(lookupValue(v
));
1180 auto iter
= constantAggregateUseMap
.find(cst
);
1181 assert(iter
!= constantAggregateUseMap
.end() && "constant not found");
1183 if (iter
->second
== 0) {
1184 // NOTE: cannot call removeDeadConstantUsers() here because it
1185 // may remove the constant which has uses not be converted yet.
1186 if (cst
->user_empty()) {
1187 cst
->destroyConstant();
1188 numConstantsErased
++;
1190 constantAggregateUseMap
.erase(iter
);
1195 ReturnOp ret
= cast
<ReturnOp
>(initializer
->getTerminator());
1196 llvm::Constant
*cst
=
1197 cast
<llvm::Constant
>(lookupValue(ret
.getOperand(0)));
1198 auto *global
= cast
<llvm::GlobalVariable
>(lookupGlobal(op
));
1199 if (!shouldDropGlobalInitializer(global
->getLinkage(), cst
))
1200 global
->setInitializer(cst
);
1202 // Try to remove the dangling constants again after all operations are
1204 for (auto it
: constantAggregateUseMap
) {
1205 auto cst
= it
.first
;
1206 cst
->removeDeadConstantUsers();
1207 if (cst
->user_empty()) {
1208 cst
->destroyConstant();
1209 numConstantsErased
++;
1213 LLVM_DEBUG(llvm::dbgs()
1214 << "Convert initializer for " << op
.getName() << "\n";
1215 llvm::dbgs() << numConstantsHit
<< " new constants hit\n";
1217 << numConstantsErased
<< " dangling constants erased\n";);
1221 // Convert llvm.mlir.global_ctors and dtors.
1222 for (Operation
&op
: getModuleBody(mlirModule
)) {
1223 auto ctorOp
= dyn_cast
<GlobalCtorsOp
>(op
);
1224 auto dtorOp
= dyn_cast
<GlobalDtorsOp
>(op
);
1225 if (!ctorOp
&& !dtorOp
)
1227 auto range
= ctorOp
? llvm::zip(ctorOp
.getCtors(), ctorOp
.getPriorities())
1228 : llvm::zip(dtorOp
.getDtors(), dtorOp
.getPriorities());
1229 auto appendGlobalFn
=
1230 ctorOp
? llvm::appendToGlobalCtors
: llvm::appendToGlobalDtors
;
1231 for (auto symbolAndPriority
: range
) {
1232 llvm::Function
*f
= lookupFunction(
1233 cast
<FlatSymbolRefAttr
>(std::get
<0>(symbolAndPriority
)).getValue());
1234 appendGlobalFn(*llvmModule
, f
,
1235 cast
<IntegerAttr
>(std::get
<1>(symbolAndPriority
)).getInt(),
1240 for (auto op
: getModuleBody(mlirModule
).getOps
<LLVM::GlobalOp
>())
1241 if (failed(convertDialectAttributes(op
, {})))
1244 // Finally, update the compile units their respective sets of global variables
1246 for (const auto &[compileUnit
, globals
] : allGVars
) {
1247 compileUnit
->replaceGlobalVariables(
1248 llvm::MDTuple::get(getLLVMContext(), globals
));
1254 /// Attempts to add an attribute identified by `key`, optionally with the given
1255 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
1256 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
1257 /// otherwise keep it as a string attribute. Performs additional checks for
1258 /// attributes known to have or not have a value in order to avoid assertions
1259 /// inside LLVM upon construction.
1260 static LogicalResult
checkedAddLLVMFnAttribute(Location loc
,
1261 llvm::Function
*llvmFunc
,
1263 StringRef value
= StringRef()) {
1264 auto kind
= llvm::Attribute::getAttrKindFromName(key
);
1265 if (kind
== llvm::Attribute::None
) {
1266 llvmFunc
->addFnAttr(key
, value
);
1270 if (llvm::Attribute::isIntAttrKind(kind
)) {
1272 return emitError(loc
) << "LLVM attribute '" << key
<< "' expects a value";
1275 if (!value
.getAsInteger(/*Radix=*/0, result
))
1276 llvmFunc
->addFnAttr(
1277 llvm::Attribute::get(llvmFunc
->getContext(), kind
, result
));
1279 llvmFunc
->addFnAttr(key
, value
);
1284 return emitError(loc
) << "LLVM attribute '" << key
1285 << "' does not expect a value, found '" << value
1288 llvmFunc
->addFnAttr(kind
);
1292 /// Return a representation of `value` as metadata.
1293 static llvm::Metadata
*convertIntegerToMetadata(llvm::LLVMContext
&context
,
1294 const llvm::APInt
&value
) {
1295 llvm::Constant
*constant
= llvm::ConstantInt::get(context
, value
);
1296 return llvm::ConstantAsMetadata::get(constant
);
1299 /// Return a representation of `value` as an MDNode.
1300 static llvm::MDNode
*convertIntegerToMDNode(llvm::LLVMContext
&context
,
1301 const llvm::APInt
&value
) {
1302 return llvm::MDNode::get(context
, convertIntegerToMetadata(context
, value
));
1305 /// Return an MDNode encoding `vec_type_hint` metadata.
1306 static llvm::MDNode
*convertVecTypeHintToMDNode(llvm::LLVMContext
&context
,
1309 llvm::Metadata
*typeMD
=
1310 llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type
));
1311 llvm::Metadata
*isSignedMD
=
1312 convertIntegerToMetadata(context
, llvm::APInt(32, isSigned
? 1 : 0));
1313 return llvm::MDNode::get(context
, {typeMD
, isSignedMD
});
1316 /// Return an MDNode with a tuple given by the values in `values`.
1317 static llvm::MDNode
*convertIntegerArrayToMDNode(llvm::LLVMContext
&context
,
1318 ArrayRef
<int32_t> values
) {
1319 SmallVector
<llvm::Metadata
*> mdValues
;
1321 values
, std::back_inserter(mdValues
), [&context
](int32_t value
) {
1322 return convertIntegerToMetadata(context
, llvm::APInt(32, value
));
1324 return llvm::MDNode::get(context
, mdValues
);
1327 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
1328 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
1329 /// to be an array attribute containing either string attributes, treated as
1330 /// value-less LLVM attributes, or array attributes containing two string
1331 /// attributes, with the first string being the name of the corresponding LLVM
1332 /// attribute and the second string beings its value. Note that even integer
1333 /// attributes are expected to have their values expressed as strings.
1334 static LogicalResult
1335 forwardPassthroughAttributes(Location loc
, std::optional
<ArrayAttr
> attributes
,
1336 llvm::Function
*llvmFunc
) {
1340 for (Attribute attr
: *attributes
) {
1341 if (auto stringAttr
= dyn_cast
<StringAttr
>(attr
)) {
1343 checkedAddLLVMFnAttribute(loc
, llvmFunc
, stringAttr
.getValue())))
1348 auto arrayAttr
= dyn_cast
<ArrayAttr
>(attr
);
1349 if (!arrayAttr
|| arrayAttr
.size() != 2)
1350 return emitError(loc
)
1351 << "expected 'passthrough' to contain string or array attributes";
1353 auto keyAttr
= dyn_cast
<StringAttr
>(arrayAttr
[0]);
1354 auto valueAttr
= dyn_cast
<StringAttr
>(arrayAttr
[1]);
1355 if (!keyAttr
|| !valueAttr
)
1356 return emitError(loc
)
1357 << "expected arrays within 'passthrough' to contain two strings";
1359 if (failed(checkedAddLLVMFnAttribute(loc
, llvmFunc
, keyAttr
.getValue(),
1360 valueAttr
.getValue())))
1366 LogicalResult
ModuleTranslation::convertOneFunction(LLVMFuncOp func
) {
1367 // Clear the block, branch value mappings, they are only relevant within one
1369 blockMapping
.clear();
1370 valueMapping
.clear();
1371 branchMapping
.clear();
1372 llvm::Function
*llvmFunc
= lookupFunction(func
.getName());
1374 // Add function arguments to the value remapping table.
1375 for (auto [mlirArg
, llvmArg
] :
1376 llvm::zip(func
.getArguments(), llvmFunc
->args()))
1377 mapValue(mlirArg
, &llvmArg
);
1379 // Check the personality and set it.
1380 if (func
.getPersonality()) {
1381 llvm::Type
*ty
= llvm::PointerType::getUnqual(llvmFunc
->getContext());
1382 if (llvm::Constant
*pfunc
= getLLVMConstant(ty
, func
.getPersonalityAttr(),
1383 func
.getLoc(), *this))
1384 llvmFunc
->setPersonalityFn(pfunc
);
1387 if (std::optional
<StringRef
> section
= func
.getSection())
1388 llvmFunc
->setSection(*section
);
1390 if (func
.getArmStreaming())
1391 llvmFunc
->addFnAttr("aarch64_pstate_sm_enabled");
1392 else if (func
.getArmLocallyStreaming())
1393 llvmFunc
->addFnAttr("aarch64_pstate_sm_body");
1394 else if (func
.getArmStreamingCompatible())
1395 llvmFunc
->addFnAttr("aarch64_pstate_sm_compatible");
1397 if (func
.getArmNewZa())
1398 llvmFunc
->addFnAttr("aarch64_new_za");
1399 else if (func
.getArmInZa())
1400 llvmFunc
->addFnAttr("aarch64_in_za");
1401 else if (func
.getArmOutZa())
1402 llvmFunc
->addFnAttr("aarch64_out_za");
1403 else if (func
.getArmInoutZa())
1404 llvmFunc
->addFnAttr("aarch64_inout_za");
1405 else if (func
.getArmPreservesZa())
1406 llvmFunc
->addFnAttr("aarch64_preserves_za");
1408 if (auto targetCpu
= func
.getTargetCpu())
1409 llvmFunc
->addFnAttr("target-cpu", *targetCpu
);
1411 if (auto tuneCpu
= func
.getTuneCpu())
1412 llvmFunc
->addFnAttr("tune-cpu", *tuneCpu
);
1414 if (auto targetFeatures
= func
.getTargetFeatures())
1415 llvmFunc
->addFnAttr("target-features", targetFeatures
->getFeaturesString());
1417 if (auto attr
= func
.getVscaleRange())
1418 llvmFunc
->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
1419 getLLVMContext(), attr
->getMinRange().getInt(),
1420 attr
->getMaxRange().getInt()));
1422 if (auto unsafeFpMath
= func
.getUnsafeFpMath())
1423 llvmFunc
->addFnAttr("unsafe-fp-math", llvm::toStringRef(*unsafeFpMath
));
1425 if (auto noInfsFpMath
= func
.getNoInfsFpMath())
1426 llvmFunc
->addFnAttr("no-infs-fp-math", llvm::toStringRef(*noInfsFpMath
));
1428 if (auto noNansFpMath
= func
.getNoNansFpMath())
1429 llvmFunc
->addFnAttr("no-nans-fp-math", llvm::toStringRef(*noNansFpMath
));
1431 if (auto approxFuncFpMath
= func
.getApproxFuncFpMath())
1432 llvmFunc
->addFnAttr("approx-func-fp-math",
1433 llvm::toStringRef(*approxFuncFpMath
));
1435 if (auto noSignedZerosFpMath
= func
.getNoSignedZerosFpMath())
1436 llvmFunc
->addFnAttr("no-signed-zeros-fp-math",
1437 llvm::toStringRef(*noSignedZerosFpMath
));
1439 if (auto denormalFpMath
= func
.getDenormalFpMath())
1440 llvmFunc
->addFnAttr("denormal-fp-math", *denormalFpMath
);
1442 if (auto denormalFpMathF32
= func
.getDenormalFpMathF32())
1443 llvmFunc
->addFnAttr("denormal-fp-math-f32", *denormalFpMathF32
);
1445 if (auto fpContract
= func
.getFpContract())
1446 llvmFunc
->addFnAttr("fp-contract", *fpContract
);
1448 // Add function attribute frame-pointer, if found.
1449 if (FramePointerKindAttr attr
= func
.getFramePointerAttr())
1450 llvmFunc
->addFnAttr("frame-pointer",
1451 LLVM::framePointerKind::stringifyFramePointerKind(
1452 (attr
.getFramePointerKind())));
1454 // First, create all blocks so we can jump to them.
1455 llvm::LLVMContext
&llvmContext
= llvmFunc
->getContext();
1456 for (auto &bb
: func
) {
1457 auto *llvmBB
= llvm::BasicBlock::Create(llvmContext
);
1458 llvmBB
->insertInto(llvmFunc
);
1459 mapBlock(&bb
, llvmBB
);
1462 // Then, convert blocks one by one in topological order to ensure defs are
1463 // converted before uses.
1464 auto blocks
= getBlocksSortedByDominance(func
.getBody());
1465 for (Block
*bb
: blocks
) {
1466 CapturingIRBuilder
builder(llvmContext
);
1467 if (failed(convertBlockImpl(*bb
, bb
->isEntryBlock(), builder
,
1468 /*recordInsertions=*/true)))
1472 // After all blocks have been traversed and values mapped, connect the PHI
1473 // nodes to the results of preceding blocks.
1474 detail::connectPHINodes(func
.getBody(), *this);
1476 // Finally, convert dialect attributes attached to the function.
1477 return convertDialectAttributes(func
, {});
1480 LogicalResult
ModuleTranslation::convertDialectAttributes(
1481 Operation
*op
, ArrayRef
<llvm::Instruction
*> instructions
) {
1482 for (NamedAttribute attribute
: op
->getDialectAttrs())
1483 if (failed(iface
.amendOperation(op
, instructions
, attribute
, *this)))
1488 /// Converts memory effect attributes from `func` and attaches them to
1490 static void convertFunctionMemoryAttributes(LLVMFuncOp func
,
1491 llvm::Function
*llvmFunc
) {
1492 if (!func
.getMemoryEffects())
1495 MemoryEffectsAttr memEffects
= func
.getMemoryEffectsAttr();
1497 // Add memory effects incrementally.
1498 llvm::MemoryEffects newMemEffects
=
1499 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem
,
1500 convertModRefInfoToLLVM(memEffects
.getArgMem()));
1501 newMemEffects
|= llvm::MemoryEffects(
1502 llvm::MemoryEffects::Location::InaccessibleMem
,
1503 convertModRefInfoToLLVM(memEffects
.getInaccessibleMem()));
1505 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other
,
1506 convertModRefInfoToLLVM(memEffects
.getOther()));
1507 llvmFunc
->setMemoryEffects(newMemEffects
);
1510 /// Converts function attributes from `func` and attaches them to `llvmFunc`.
1511 static void convertFunctionAttributes(LLVMFuncOp func
,
1512 llvm::Function
*llvmFunc
) {
1513 if (func
.getNoInlineAttr())
1514 llvmFunc
->addFnAttr(llvm::Attribute::NoInline
);
1515 if (func
.getAlwaysInlineAttr())
1516 llvmFunc
->addFnAttr(llvm::Attribute::AlwaysInline
);
1517 if (func
.getOptimizeNoneAttr())
1518 llvmFunc
->addFnAttr(llvm::Attribute::OptimizeNone
);
1519 if (func
.getConvergentAttr())
1520 llvmFunc
->addFnAttr(llvm::Attribute::Convergent
);
1521 if (func
.getNoUnwindAttr())
1522 llvmFunc
->addFnAttr(llvm::Attribute::NoUnwind
);
1523 if (func
.getWillReturnAttr())
1524 llvmFunc
->addFnAttr(llvm::Attribute::WillReturn
);
1525 convertFunctionMemoryAttributes(func
, llvmFunc
);
1528 /// Converts function attributes from `func` and attaches them to `llvmFunc`.
1529 static void convertFunctionKernelAttributes(LLVMFuncOp func
,
1530 llvm::Function
*llvmFunc
,
1531 ModuleTranslation
&translation
) {
1532 llvm::LLVMContext
&llvmContext
= llvmFunc
->getContext();
1534 if (VecTypeHintAttr vecTypeHint
= func
.getVecTypeHintAttr()) {
1535 Type type
= vecTypeHint
.getHint().getValue();
1536 llvm::Type
*llvmType
= translation
.convertType(type
);
1537 bool isSigned
= vecTypeHint
.getIsSigned();
1538 llvmFunc
->setMetadata(
1539 func
.getVecTypeHintAttrName(),
1540 convertVecTypeHintToMDNode(llvmContext
, llvmType
, isSigned
));
1543 if (std::optional
<ArrayRef
<int32_t>> workGroupSizeHint
=
1544 func
.getWorkGroupSizeHint()) {
1545 llvmFunc
->setMetadata(
1546 func
.getWorkGroupSizeHintAttrName(),
1547 convertIntegerArrayToMDNode(llvmContext
, *workGroupSizeHint
));
1550 if (std::optional
<ArrayRef
<int32_t>> reqdWorkGroupSize
=
1551 func
.getReqdWorkGroupSize()) {
1552 llvmFunc
->setMetadata(
1553 func
.getReqdWorkGroupSizeAttrName(),
1554 convertIntegerArrayToMDNode(llvmContext
, *reqdWorkGroupSize
));
1557 if (std::optional
<uint32_t> intelReqdSubGroupSize
=
1558 func
.getIntelReqdSubGroupSize()) {
1559 llvmFunc
->setMetadata(
1560 func
.getIntelReqdSubGroupSizeAttrName(),
1561 convertIntegerToMDNode(llvmContext
,
1562 llvm::APInt(32, *intelReqdSubGroupSize
)));
1566 FailureOr
<llvm::AttrBuilder
>
1567 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func
, int argIdx
,
1568 DictionaryAttr paramAttrs
) {
1569 llvm::AttrBuilder
attrBuilder(llvmModule
->getContext());
1570 auto attrNameToKindMapping
= getAttrNameToKindMapping();
1572 for (auto namedAttr
: paramAttrs
) {
1573 auto it
= attrNameToKindMapping
.find(namedAttr
.getName());
1574 if (it
!= attrNameToKindMapping
.end()) {
1575 llvm::Attribute::AttrKind llvmKind
= it
->second
;
1577 llvm::TypeSwitch
<Attribute
>(namedAttr
.getValue())
1578 .Case
<TypeAttr
>([&](auto typeAttr
) {
1579 attrBuilder
.addTypeAttr(llvmKind
, convertType(typeAttr
.getValue()));
1581 .Case
<IntegerAttr
>([&](auto intAttr
) {
1582 attrBuilder
.addRawIntAttr(llvmKind
, intAttr
.getInt());
1584 .Case
<UnitAttr
>([&](auto) { attrBuilder
.addAttribute(llvmKind
); })
1585 .Case
<LLVM::ConstantRangeAttr
>([&](auto rangeAttr
) {
1586 attrBuilder
.addConstantRangeAttr(
1587 llvmKind
, llvm::ConstantRange(rangeAttr
.getLower(),
1588 rangeAttr
.getUpper()));
1590 } else if (namedAttr
.getNameDialect()) {
1591 if (failed(iface
.convertParameterAttr(func
, argIdx
, namedAttr
, *this)))
1599 LogicalResult
ModuleTranslation::convertFunctionSignatures() {
1600 // Declare all functions first because there may be function calls that form a
1601 // call graph with cycles, or global initializers that reference functions.
1602 for (auto function
: getModuleBody(mlirModule
).getOps
<LLVMFuncOp
>()) {
1603 llvm::FunctionCallee llvmFuncCst
= llvmModule
->getOrInsertFunction(
1605 cast
<llvm::FunctionType
>(convertType(function
.getFunctionType())));
1606 llvm::Function
*llvmFunc
= cast
<llvm::Function
>(llvmFuncCst
.getCallee());
1607 llvmFunc
->setLinkage(convertLinkageToLLVM(function
.getLinkage()));
1608 llvmFunc
->setCallingConv(convertCConvToLLVM(function
.getCConv()));
1609 mapFunction(function
.getName(), llvmFunc
);
1610 addRuntimePreemptionSpecifier(function
.getDsoLocal(), llvmFunc
);
1612 // Convert function attributes.
1613 convertFunctionAttributes(function
, llvmFunc
);
1615 // Convert function kernel attributes to metadata.
1616 convertFunctionKernelAttributes(function
, llvmFunc
, *this);
1618 // Convert function_entry_count attribute to metadata.
1619 if (std::optional
<uint64_t> entryCount
= function
.getFunctionEntryCount())
1620 llvmFunc
->setEntryCount(entryCount
.value());
1622 // Convert result attributes.
1623 if (ArrayAttr allResultAttrs
= function
.getAllResultAttrs()) {
1624 DictionaryAttr resultAttrs
= cast
<DictionaryAttr
>(allResultAttrs
[0]);
1625 FailureOr
<llvm::AttrBuilder
> attrBuilder
=
1626 convertParameterAttrs(function
, -1, resultAttrs
);
1627 if (failed(attrBuilder
))
1629 llvmFunc
->addRetAttrs(*attrBuilder
);
1632 // Convert argument attributes.
1633 for (auto [argIdx
, llvmArg
] : llvm::enumerate(llvmFunc
->args())) {
1634 if (DictionaryAttr argAttrs
= function
.getArgAttrDict(argIdx
)) {
1635 FailureOr
<llvm::AttrBuilder
> attrBuilder
=
1636 convertParameterAttrs(function
, argIdx
, argAttrs
);
1637 if (failed(attrBuilder
))
1639 llvmArg
.addAttrs(*attrBuilder
);
1643 // Forward the pass-through attributes to LLVM.
1644 if (failed(forwardPassthroughAttributes(
1645 function
.getLoc(), function
.getPassthrough(), llvmFunc
)))
1648 // Convert visibility attribute.
1649 llvmFunc
->setVisibility(convertVisibilityToLLVM(function
.getVisibility_()));
1651 // Convert the comdat attribute.
1652 if (std::optional
<mlir::SymbolRefAttr
> comdat
= function
.getComdat()) {
1653 auto selectorOp
= cast
<ComdatSelectorOp
>(
1654 SymbolTable::lookupNearestSymbolFrom(function
, *comdat
));
1655 llvmFunc
->setComdat(comdatMapping
.lookup(selectorOp
));
1658 if (auto gc
= function
.getGarbageCollector())
1659 llvmFunc
->setGC(gc
->str());
1661 if (auto unnamedAddr
= function
.getUnnamedAddr())
1662 llvmFunc
->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr
));
1664 if (auto alignment
= function
.getAlignment())
1665 llvmFunc
->setAlignment(llvm::MaybeAlign(*alignment
));
1667 // Translate the debug information for this function.
1668 debugTranslation
->translate(function
, *llvmFunc
);
1674 LogicalResult
ModuleTranslation::convertFunctions() {
1675 // Convert functions.
1676 for (auto function
: getModuleBody(mlirModule
).getOps
<LLVMFuncOp
>()) {
1677 // Do not convert external functions, but do process dialect attributes
1678 // attached to them.
1679 if (function
.isExternal()) {
1680 if (failed(convertDialectAttributes(function
, {})))
1685 if (failed(convertOneFunction(function
)))
1692 LogicalResult
ModuleTranslation::convertComdats() {
1693 for (auto comdatOp
: getModuleBody(mlirModule
).getOps
<ComdatOp
>()) {
1694 for (auto selectorOp
: comdatOp
.getOps
<ComdatSelectorOp
>()) {
1695 llvm::Module
*module
= getLLVMModule();
1696 if (module
->getComdatSymbolTable().contains(selectorOp
.getSymName()))
1697 return emitError(selectorOp
.getLoc())
1698 << "comdat selection symbols must be unique even in different "
1700 llvm::Comdat
*comdat
= module
->getOrInsertComdat(selectorOp
.getSymName());
1701 comdat
->setSelectionKind(convertComdatToLLVM(selectorOp
.getComdat()));
1702 comdatMapping
.try_emplace(selectorOp
, comdat
);
1708 void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op
,
1709 llvm::Instruction
*inst
) {
1710 if (llvm::MDNode
*node
= loopAnnotationTranslation
->getAccessGroups(op
))
1711 inst
->setMetadata(llvm::LLVMContext::MD_access_group
, node
);
1715 ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr
) {
1716 auto [scopeIt
, scopeInserted
] =
1717 aliasScopeMetadataMapping
.try_emplace(aliasScopeAttr
, nullptr);
1719 return scopeIt
->second
;
1720 llvm::LLVMContext
&ctx
= llvmModule
->getContext();
1721 auto dummy
= llvm::MDNode::getTemporary(ctx
, std::nullopt
);
1722 // Convert the domain metadata node if necessary.
1723 auto [domainIt
, insertedDomain
] = aliasDomainMetadataMapping
.try_emplace(
1724 aliasScopeAttr
.getDomain(), nullptr);
1725 if (insertedDomain
) {
1726 llvm::SmallVector
<llvm::Metadata
*, 2> operands
;
1727 // Placeholder for self-reference.
1728 operands
.push_back(dummy
.get());
1729 if (StringAttr description
= aliasScopeAttr
.getDomain().getDescription())
1730 operands
.push_back(llvm::MDString::get(ctx
, description
));
1731 domainIt
->second
= llvm::MDNode::get(ctx
, operands
);
1732 // Self-reference for uniqueness.
1733 domainIt
->second
->replaceOperandWith(0, domainIt
->second
);
1735 // Convert the scope metadata node.
1736 assert(domainIt
->second
&& "Scope's domain should already be valid");
1737 llvm::SmallVector
<llvm::Metadata
*, 3> operands
;
1738 // Placeholder for self-reference.
1739 operands
.push_back(dummy
.get());
1740 operands
.push_back(domainIt
->second
);
1741 if (StringAttr description
= aliasScopeAttr
.getDescription())
1742 operands
.push_back(llvm::MDString::get(ctx
, description
));
1743 scopeIt
->second
= llvm::MDNode::get(ctx
, operands
);
1744 // Self-reference for uniqueness.
1745 scopeIt
->second
->replaceOperandWith(0, scopeIt
->second
);
1746 return scopeIt
->second
;
1749 llvm::MDNode
*ModuleTranslation::getOrCreateAliasScopes(
1750 ArrayRef
<AliasScopeAttr
> aliasScopeAttrs
) {
1751 SmallVector
<llvm::Metadata
*> nodes
;
1752 nodes
.reserve(aliasScopeAttrs
.size());
1753 for (AliasScopeAttr aliasScopeAttr
: aliasScopeAttrs
)
1754 nodes
.push_back(getOrCreateAliasScope(aliasScopeAttr
));
1755 return llvm::MDNode::get(getLLVMContext(), nodes
);
1758 void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op
,
1759 llvm::Instruction
*inst
) {
1760 auto populateScopeMetadata
= [&](ArrayAttr aliasScopeAttrs
, unsigned kind
) {
1761 if (!aliasScopeAttrs
|| aliasScopeAttrs
.empty())
1763 llvm::MDNode
*node
= getOrCreateAliasScopes(
1764 llvm::to_vector(aliasScopeAttrs
.getAsRange
<AliasScopeAttr
>()));
1765 inst
->setMetadata(kind
, node
);
1768 populateScopeMetadata(op
.getAliasScopesOrNull(),
1769 llvm::LLVMContext::MD_alias_scope
);
1770 populateScopeMetadata(op
.getNoAliasScopesOrNull(),
1771 llvm::LLVMContext::MD_noalias
);
1774 llvm::MDNode
*ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr
) const {
1775 return tbaaMetadataMapping
.lookup(tbaaAttr
);
1778 void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op
,
1779 llvm::Instruction
*inst
) {
1780 ArrayAttr tagRefs
= op
.getTBAATagsOrNull();
1781 if (!tagRefs
|| tagRefs
.empty())
1784 // LLVM IR currently does not support attaching more than one TBAA access tag
1785 // to a memory accessing instruction. It may be useful to support this in
1786 // future, but for the time being just ignore the metadata if MLIR operation
1787 // has multiple access tags.
1788 if (tagRefs
.size() > 1) {
1789 op
.emitWarning() << "TBAA access tags were not translated, because LLVM "
1790 "IR only supports a single tag per instruction";
1794 llvm::MDNode
*node
= getTBAANode(cast
<TBAATagAttr
>(tagRefs
[0]));
1795 inst
->setMetadata(llvm::LLVMContext::MD_tbaa
, node
);
1798 void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op
) {
1799 DenseI32ArrayAttr weightsAttr
= op
.getBranchWeightsOrNull();
1803 llvm::Instruction
*inst
= isa
<CallOp
>(op
) ? lookupCall(op
) : lookupBranch(op
);
1804 assert(inst
&& "expected the operation to have a mapping to an instruction");
1805 SmallVector
<uint32_t> weights(weightsAttr
.asArrayRef());
1807 llvm::LLVMContext::MD_prof
,
1808 llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights
));
1811 LogicalResult
ModuleTranslation::createTBAAMetadata() {
1812 llvm::LLVMContext
&ctx
= llvmModule
->getContext();
1813 llvm::IntegerType
*offsetTy
= llvm::IntegerType::get(ctx
, 64);
1815 // Walk the entire module and create all metadata nodes for the TBAA
1816 // attributes. The code below relies on two invariants of the
1817 // `AttrTypeWalker`:
1818 // 1. Attributes are visited in post-order: Since the attributes create a DAG,
1819 // this ensures that any lookups into `tbaaMetadataMapping` for child
1820 // attributes succeed.
1821 // 2. Attributes are only ever visited once: This way we don't leak any
1822 // LLVM metadata instances.
1823 AttrTypeWalker walker
;
1824 walker
.addWalk([&](TBAARootAttr root
) {
1825 tbaaMetadataMapping
.insert(
1826 {root
, llvm::MDNode::get(ctx
, llvm::MDString::get(ctx
, root
.getId()))});
1829 walker
.addWalk([&](TBAATypeDescriptorAttr descriptor
) {
1830 SmallVector
<llvm::Metadata
*> operands
;
1831 operands
.push_back(llvm::MDString::get(ctx
, descriptor
.getId()));
1832 for (TBAAMemberAttr member
: descriptor
.getMembers()) {
1833 operands
.push_back(tbaaMetadataMapping
.lookup(member
.getTypeDesc()));
1834 operands
.push_back(llvm::ConstantAsMetadata::get(
1835 llvm::ConstantInt::get(offsetTy
, member
.getOffset())));
1838 tbaaMetadataMapping
.insert({descriptor
, llvm::MDNode::get(ctx
, operands
)});
1841 walker
.addWalk([&](TBAATagAttr tag
) {
1842 SmallVector
<llvm::Metadata
*> operands
;
1844 operands
.push_back(tbaaMetadataMapping
.lookup(tag
.getBaseType()));
1845 operands
.push_back(tbaaMetadataMapping
.lookup(tag
.getAccessType()));
1847 operands
.push_back(llvm::ConstantAsMetadata::get(
1848 llvm::ConstantInt::get(offsetTy
, tag
.getOffset())));
1849 if (tag
.getConstant())
1851 llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(offsetTy
, 1)));
1853 tbaaMetadataMapping
.insert({tag
, llvm::MDNode::get(ctx
, operands
)});
1856 mlirModule
->walk([&](AliasAnalysisOpInterface analysisOpInterface
) {
1857 if (auto attr
= analysisOpInterface
.getTBAATagsOrNull())
1864 LogicalResult
ModuleTranslation::createIdentMetadata() {
1865 if (auto attr
= mlirModule
->getAttrOfType
<StringAttr
>(
1866 LLVMDialect::getIdentAttrName())) {
1867 StringRef ident
= attr
;
1868 llvm::LLVMContext
&ctx
= llvmModule
->getContext();
1869 llvm::NamedMDNode
*namedMd
=
1870 llvmModule
->getOrInsertNamedMetadata(LLVMDialect::getIdentAttrName());
1871 llvm::MDNode
*md
= llvm::MDNode::get(ctx
, llvm::MDString::get(ctx
, ident
));
1872 namedMd
->addOperand(md
);
1878 LogicalResult
ModuleTranslation::createCommandlineMetadata() {
1879 if (auto attr
= mlirModule
->getAttrOfType
<StringAttr
>(
1880 LLVMDialect::getCommandlineAttrName())) {
1881 StringRef cmdLine
= attr
;
1882 llvm::LLVMContext
&ctx
= llvmModule
->getContext();
1883 llvm::NamedMDNode
*nmd
= llvmModule
->getOrInsertNamedMetadata(
1884 LLVMDialect::getCommandlineAttrName());
1886 llvm::MDNode::get(ctx
, llvm::MDString::get(ctx
, cmdLine
));
1887 nmd
->addOperand(md
);
1893 void ModuleTranslation::setLoopMetadata(Operation
*op
,
1894 llvm::Instruction
*inst
) {
1895 LoopAnnotationAttr attr
=
1896 TypeSwitch
<Operation
*, LoopAnnotationAttr
>(op
)
1897 .Case
<LLVM::BrOp
, LLVM::CondBrOp
>(
1898 [](auto branchOp
) { return branchOp
.getLoopAnnotationAttr(); });
1901 llvm::MDNode
*loopMD
=
1902 loopAnnotationTranslation
->translateLoopAnnotation(attr
, op
);
1903 inst
->setMetadata(llvm::LLVMContext::MD_loop
, loopMD
);
1906 void ModuleTranslation::setDisjointFlag(Operation
*op
, llvm::Value
*value
) {
1907 auto iface
= cast
<DisjointFlagInterface
>(op
);
1908 // We do a dyn_cast here in case the value got folded into a constant.
1909 if (auto disjointInst
= dyn_cast
<llvm::PossiblyDisjointInst
>(value
))
1910 disjointInst
->setIsDisjoint(iface
.getIsDisjoint());
1913 llvm::Type
*ModuleTranslation::convertType(Type type
) {
1914 return typeTranslator
.translateType(type
);
1917 /// A helper to look up remapped operands in the value remapping table.
1918 SmallVector
<llvm::Value
*> ModuleTranslation::lookupValues(ValueRange values
) {
1919 SmallVector
<llvm::Value
*> remapped
;
1920 remapped
.reserve(values
.size());
1921 for (Value v
: values
)
1922 remapped
.push_back(lookupValue(v
));
1926 llvm::OpenMPIRBuilder
*ModuleTranslation::getOpenMPBuilder() {
1928 ompBuilder
= std::make_unique
<llvm::OpenMPIRBuilder
>(*llvmModule
);
1929 ompBuilder
->initialize();
1931 // Flags represented as top-level OpenMP dialect attributes are set in
1932 // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set
1933 // the default configuration.
1934 ompBuilder
->setConfig(llvm::OpenMPIRBuilderConfig(
1935 /* IsTargetDevice = */ false, /* IsGPU = */ false,
1936 /* OpenMPOffloadMandatory = */ false,
1937 /* HasRequiresReverseOffload = */ false,
1938 /* HasRequiresUnifiedAddress = */ false,
1939 /* HasRequiresUnifiedSharedMemory = */ false,
1940 /* HasRequiresDynamicAllocators = */ false));
1942 return ompBuilder
.get();
1945 llvm::DILocation
*ModuleTranslation::translateLoc(Location loc
,
1946 llvm::DILocalScope
*scope
) {
1947 return debugTranslation
->translateLoc(loc
, scope
);
1950 llvm::DIExpression
*
1951 ModuleTranslation::translateExpression(LLVM::DIExpressionAttr attr
) {
1952 return debugTranslation
->translateExpression(attr
);
1955 llvm::DIGlobalVariableExpression
*
1956 ModuleTranslation::translateGlobalVariableExpression(
1957 LLVM::DIGlobalVariableExpressionAttr attr
) {
1958 return debugTranslation
->translateGlobalVariableExpression(attr
);
1961 llvm::Metadata
*ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr
) {
1962 return debugTranslation
->translate(attr
);
1966 ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding
) {
1967 return convertRoundingModeToLLVM(rounding
);
1970 llvm::fp::ExceptionBehavior
ModuleTranslation::translateFPExceptionBehavior(
1971 LLVM::FPExceptionBehavior exceptionBehavior
) {
1972 return convertFPExceptionBehaviorToLLVM(exceptionBehavior
);
1976 ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name
) {
1977 return llvmModule
->getOrInsertNamedMetadata(name
);
1980 void ModuleTranslation::StackFrame::anchor() {}
1982 static std::unique_ptr
<llvm::Module
>
1983 prepareLLVMModule(Operation
*m
, llvm::LLVMContext
&llvmContext
,
1985 m
->getContext()->getOrLoadDialect
<LLVM::LLVMDialect
>();
1986 auto llvmModule
= std::make_unique
<llvm::Module
>(name
, llvmContext
);
1987 // ModuleTranslation can currently only construct modules in the old debug
1988 // info format, so set the flag accordingly.
1989 llvmModule
->setNewDbgInfoFormatFlag(false);
1990 if (auto dataLayoutAttr
=
1991 m
->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
1992 llvmModule
->setDataLayout(cast
<StringAttr
>(dataLayoutAttr
).getValue());
1994 FailureOr
<llvm::DataLayout
> llvmDataLayout(llvm::DataLayout(""));
1995 if (auto iface
= dyn_cast
<DataLayoutOpInterface
>(m
)) {
1996 if (DataLayoutSpecInterface spec
= iface
.getDataLayoutSpec()) {
1998 translateDataLayout(spec
, DataLayout(iface
), m
->getLoc());
2000 } else if (auto mod
= dyn_cast
<ModuleOp
>(m
)) {
2001 if (DataLayoutSpecInterface spec
= mod
.getDataLayoutSpec()) {
2003 translateDataLayout(spec
, DataLayout(mod
), m
->getLoc());
2006 if (failed(llvmDataLayout
))
2008 llvmModule
->setDataLayout(*llvmDataLayout
);
2010 if (auto targetTripleAttr
=
2011 m
->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
2012 llvmModule
->setTargetTriple(cast
<StringAttr
>(targetTripleAttr
).getValue());
2017 std::unique_ptr
<llvm::Module
>
2018 mlir::translateModuleToLLVMIR(Operation
*module
, llvm::LLVMContext
&llvmContext
,
2019 StringRef name
, bool disableVerification
) {
2020 if (!satisfiesLLVMModule(module
)) {
2021 module
->emitOpError("can not be translated to an LLVMIR module");
2025 std::unique_ptr
<llvm::Module
> llvmModule
=
2026 prepareLLVMModule(module
, llvmContext
, name
);
2030 LLVM::ensureDistinctSuccessors(module
);
2031 LLVM::legalizeDIExpressionsRecursively(module
);
2033 ModuleTranslation
translator(module
, std::move(llvmModule
));
2034 llvm::IRBuilder
<> llvmBuilder(llvmContext
);
2036 // Convert module before functions and operations inside, so dialect
2037 // attributes can be used to change dialect-specific global configurations via
2038 // `amendOperation()`. These configurations can then influence the translation
2039 // of operations afterwards.
2040 if (failed(translator
.convertOperation(*module
, llvmBuilder
)))
2043 if (failed(translator
.convertComdats()))
2045 if (failed(translator
.convertFunctionSignatures()))
2047 if (failed(translator
.convertGlobals()))
2049 if (failed(translator
.createTBAAMetadata()))
2051 if (failed(translator
.createIdentMetadata()))
2053 if (failed(translator
.createCommandlineMetadata()))
2056 // Convert other top-level operations if possible.
2057 for (Operation
&o
: getModuleBody(module
).getOperations()) {
2058 if (!isa
<LLVM::LLVMFuncOp
, LLVM::GlobalOp
, LLVM::GlobalCtorsOp
,
2059 LLVM::GlobalDtorsOp
, LLVM::ComdatOp
>(&o
) &&
2060 !o
.hasTrait
<OpTrait::IsTerminator
>() &&
2061 failed(translator
.convertOperation(o
, llvmBuilder
))) {
2066 // Operations in function bodies with symbolic references must be converted
2067 // after the top-level operations they refer to are declared, so we do it
2069 if (failed(translator
.convertFunctions()))
2072 // Once we've finished constructing elements in the module, we should convert
2073 // it to use the debug info format desired by LLVM.
2074 // See https://llvm.org/docs/RemoveDIsDebugInfo.html
2075 translator
.llvmModule
->setIsNewDbgInfoFormat(UseNewDbgInfoFormat
);
2077 if (!disableVerification
&&
2078 llvm::verifyModule(*translator
.llvmModule
, &llvm::errs()))
2081 return std::move(translator
.llvmModule
);