1 //===- MlirTranslateMain.cpp - MLIR Translation entry point ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
10 #include "mlir/IR/AsmState.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/Verifier.h"
14 #include "mlir/Parser/Parser.h"
15 #include "mlir/Support/FileUtilities.h"
16 #include "mlir/Support/Timing.h"
17 #include "mlir/Support/ToolUtilities.h"
18 #include "mlir/Tools/mlir-translate/Translation.h"
19 #include "llvm/Support/InitLLVM.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "llvm/Support/ToolOutputFile.h"
25 //===----------------------------------------------------------------------===//
27 //===----------------------------------------------------------------------===//
30 /// A scoped diagnostic handler that marks non-error diagnostics as handled. As
31 /// a result, the main diagnostic handler does not print non-error diagnostics.
32 class ErrorDiagnosticFilter
: public ScopedDiagnosticHandler
{
34 ErrorDiagnosticFilter(MLIRContext
*ctx
) : ScopedDiagnosticHandler(ctx
) {
35 setHandler([](Diagnostic
&diag
) {
36 if (diag
.getSeverity() != DiagnosticSeverity::Error
)
44 //===----------------------------------------------------------------------===//
45 // Translate Entry Point
46 //===----------------------------------------------------------------------===//
48 LogicalResult
mlir::mlirTranslateMain(int argc
, char **argv
,
49 llvm::StringRef toolName
) {
51 static llvm::cl::opt
<std::string
> inputFilename(
52 llvm::cl::Positional
, llvm::cl::desc("<input file>"),
55 static llvm::cl::opt
<std::string
> outputFilename(
56 "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
59 static llvm::cl::opt
<bool> allowUnregisteredDialects(
60 "allow-unregistered-dialect",
61 llvm::cl::desc("Allow operation with no registered dialects (discouraged: testing only!)"),
62 llvm::cl::init(false));
64 static llvm::cl::opt
<std::string
> inputSplitMarker
{
65 "split-input-file", llvm::cl::ValueOptional
,
66 llvm::cl::callback([&](const std::string
&str
) {
67 // Implicit value: use default marker if flag was used without value.
69 inputSplitMarker
.setValue(kDefaultSplitMarker
);
71 llvm::cl::desc("Split the input file into chunks using the given or "
72 "default marker and process each chunk independently"),
75 static llvm::cl::opt
<bool> verifyDiagnostics(
77 llvm::cl::desc("Check that emitted diagnostics match "
78 "expected-* lines on the corresponding line"),
79 llvm::cl::init(false));
81 static llvm::cl::opt
<bool> errorDiagnosticsOnly(
82 "error-diagnostics-only",
83 llvm::cl::desc("Filter all non-error diagnostics "
84 "(discouraged: testing only!)"),
85 llvm::cl::init(false));
87 static llvm::cl::opt
<std::string
> outputSplitMarker(
88 "output-split-marker",
89 llvm::cl::desc("Split marker to use for merging the ouput"),
92 llvm::InitLLVM
y(argc
, argv
);
94 // Add flags for all the registered translations.
95 llvm::cl::list
<const Translation
*, bool, TranslationParser
>
96 translationsRequested("", llvm::cl::desc("Translations to perform"),
98 registerAsmPrinterCLOptions();
99 registerMLIRContextCLOptions();
100 registerTranslationCLOptions();
101 registerDefaultTimingManagerCLOptions();
102 llvm::cl::ParseCommandLineOptions(argc
, argv
, toolName
);
104 // Initialize the timing manager.
105 DefaultTimingManager tm
;
106 applyDefaultTimingManagerCLOptions(tm
);
107 TimingScope timing
= tm
.getRootScope();
109 std::string errorMessage
;
110 std::unique_ptr
<llvm::MemoryBuffer
> input
;
111 if (auto inputAlignment
= translationsRequested
[0]->getInputAlignment())
112 input
= openInputFile(inputFilename
, *inputAlignment
, &errorMessage
);
114 input
= openInputFile(inputFilename
, &errorMessage
);
116 llvm::errs() << errorMessage
<< "\n";
120 auto output
= openOutputFile(outputFilename
, &errorMessage
);
122 llvm::errs() << errorMessage
<< "\n";
126 // Processes the memory buffer with a new MLIRContext.
127 auto processBuffer
= [&](std::unique_ptr
<llvm::MemoryBuffer
> ownedBuffer
,
129 // Temporary buffers for chained translation processing.
132 LogicalResult result
= LogicalResult::success();
134 for (size_t i
= 0, e
= translationsRequested
.size(); i
< e
; ++i
) {
135 llvm::raw_ostream
*stream
;
136 llvm::raw_string_ostream
dataStream(dataOut
);
139 // Output last translation to output.
142 // Output translation to temporary data buffer.
143 stream
= &dataStream
;
146 const Translation
*translationRequested
= translationsRequested
[i
];
147 TimingScope translationTiming
=
148 timing
.nest(translationRequested
->getDescription());
151 context
.allowUnregisteredDialects(allowUnregisteredDialects
);
152 context
.printOpOnDiagnostic(!verifyDiagnostics
);
153 auto sourceMgr
= std::make_shared
<llvm::SourceMgr
>();
154 sourceMgr
->AddNewSourceBuffer(std::move(ownedBuffer
), SMLoc());
156 if (verifyDiagnostics
) {
157 // In the diagnostic verification flow, we ignore whether the
158 // translation failed (in most cases, it is expected to fail) and we do
159 // not filter non-error diagnostics even if `errorDiagnosticsOnly` is
160 // set. Instead, we check if the diagnostics were produced as expected.
161 SourceMgrDiagnosticVerifierHandler
sourceMgrHandler(*sourceMgr
,
163 (void)(*translationRequested
)(sourceMgr
, os
, &context
);
164 result
= sourceMgrHandler
.verify();
165 } else if (errorDiagnosticsOnly
) {
166 SourceMgrDiagnosticHandler
sourceMgrHandler(*sourceMgr
, &context
);
167 ErrorDiagnosticFilter
diagnosticFilter(&context
);
168 result
= (*translationRequested
)(sourceMgr
, *stream
, &context
);
170 SourceMgrDiagnosticHandler
sourceMgrHandler(*sourceMgr
, &context
);
171 result
= (*translationRequested
)(sourceMgr
, *stream
, &context
);
177 // If there are further translations, create a new buffer with the
181 ownedBuffer
= llvm::MemoryBuffer::getMemBuffer(dataIn
);
187 if (failed(splitAndProcessBuffer(std::move(input
), processBuffer
,
188 output
->os(), inputSplitMarker
,