Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / ConvertToLLVM / ConvertToLLVMPass.cpp
blob b2407a258c2719420e00e5b60868c63251e6f3a6
1 //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
10 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include <memory>
20 #define DEBUG_TYPE "convert-to-llvm"
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTTOLLVMPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
27 using namespace mlir;
29 namespace {
31 /// This DialectExtension can be attached to the context, which will invoke the
32 /// `apply()` method for every loaded dialect. If a dialect implements the
33 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
34 /// through the interface. This extension is loaded in the context before
35 /// starting a pass pipeline that involves dialect conversion to LLVM.
36 class LoadDependentDialectExtension : public DialectExtensionBase {
37 public:
38 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
40 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
42 void apply(MLIRContext *context,
43 MutableArrayRef<Dialect *> dialects) const final {
44 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
45 for (Dialect *dialect : dialects) {
46 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
47 if (!iface)
48 continue;
49 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
50 << dialect->getNamespace() << "\n");
51 iface->loadDependentDialects(context);
55 /// Return a copy of this extension.
56 std::unique_ptr<DialectExtensionBase> clone() const final {
57 return std::make_unique<LoadDependentDialectExtension>(*this);
61 /// This is a generic pass to convert to LLVM, it uses the
62 /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
63 /// the injection of conversion patterns.
64 class ConvertToLLVMPass
65 : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
66 std::shared_ptr<const FrozenRewritePatternSet> patterns;
67 std::shared_ptr<const ConversionTarget> target;
68 std::shared_ptr<const LLVMTypeConverter> typeConverter;
70 public:
71 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
72 void getDependentDialects(DialectRegistry &registry) const final {
73 registry.insert<LLVM::LLVMDialect>();
74 registry.addExtensions<LoadDependentDialectExtension>();
77 LogicalResult initialize(MLIRContext *context) final {
78 RewritePatternSet tempPatterns(context);
79 auto target = std::make_shared<ConversionTarget>(*context);
80 target->addLegalDialect<LLVM::LLVMDialect>();
81 auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
83 if (!filterDialects.empty()) {
84 // Test mode: Populate only patterns from the specified dialects. Produce
85 // an error if the dialect is not loaded or does not implement the
86 // interface.
87 for (std::string &dialectName : filterDialects) {
88 Dialect *dialect = context->getLoadedDialect(dialectName);
89 if (!dialect)
90 return emitError(UnknownLoc::get(context))
91 << "dialect not loaded: " << dialectName << "\n";
92 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
93 if (!iface)
94 return emitError(UnknownLoc::get(context))
95 << "dialect does not implement ConvertToLLVMPatternInterface: "
96 << dialectName << "\n";
97 iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
98 tempPatterns);
100 } else {
101 // Normal mode: Populate all patterns from all dialects that implement the
102 // interface.
103 for (Dialect *dialect : context->getLoadedDialects()) {
104 // First time we encounter this dialect: if it implements the interface,
105 // let's populate patterns !
106 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
107 if (!iface)
108 continue;
109 iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
110 tempPatterns);
114 this->patterns =
115 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
116 this->target = target;
117 this->typeConverter = typeConverter;
118 return success();
121 void runOnOperation() final {
122 if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
123 signalPassFailure();
127 } // namespace
129 void mlir::registerConvertToLLVMDependentDialectLoading(
130 DialectRegistry &registry) {
131 registry.addExtensions<LoadDependentDialectExtension>();
134 std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
135 return std::make_unique<ConvertToLLVMPass>();