//===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
10 #include "mlir/Dialect/Transform/IR/Utils.h"
11 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
12 #include "mlir/IR/AsmState.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Diagnostics.h"
15 #include "mlir/IR/DialectRegistry.h"
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/InitAllDialects.h"
18 #include "mlir/InitAllExtensions.h"
19 #include "mlir/InitAllPasses.h"
20 #include "mlir/Parser/Parser.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/InitLLVM.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "llvm/Support/ToolOutputFile.h"
/// Structure containing the command-line options of the tool; the options are
/// initialized when an instance of the structure is created.
35 struct MlirTransformOptCLOptions
{
/// Forwarded to MLIRContext::allowUnregisteredDialects when the context is
/// created for processing the payload.
36 cl::opt
<bool> allowUnregisteredDialects
{
37 "allow-unregistered-dialect",
38 cl::desc("Allow operations coming from an unregistered dialect"),
/// Selects diagnostic verification: each parsed buffer then gets a verifier
/// handler that checks emitted diagnostics against `expected-*` annotations.
41 cl::opt
<bool> verifyDiagnostics
{
43 cl::desc("Check that emitted diagnostics match expected-* lines "
44 "on the corresponding line"),
/// Positional argument: path of the payload input file opened as the main
/// input buffer.
47 cl::opt
<std::string
> payloadFilename
{cl::Positional
, cl::desc("<input file>"),
/// `-o`: path passed to mlir::openOutputFile for the transformed payload.
50 cl::opt
<std::string
> outputFilename
{"o", cl::desc("Output filename"),
51 cl::value_desc("filename"),
/// Optional separate file holding the transform script entry point. Empty by
/// default; a warning is printed when it names the same file as the payload.
54 cl::opt
<std::string
> transformMainFilename
{
56 cl::desc("File containing entry point of the transform script, if "
57 "different from the input file"),
58 cl::value_desc("filename"), cl::init("")};
/// Zero or more library files; each one is parsed and its symbols are merged
/// into the transform module before the entry point is looked up.
60 cl::list
<std::string
> transformLibraryFilenames
{
61 "transform-library", cl::desc("File(s) containing definitions of "
62 "additional transform script symbols")};
/// Symbol name searched for by findTransformEntryPoint; defaults to the
/// transform dialect's kTransformEntryPointSymbolName.
64 cl::opt
<std::string
> transformEntryPoint
{
65 "transform-entry-point",
66 cl::desc("Name of the entry point transform symbol"),
67 cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName
/// Negated into TransformOptions::enableExpensiveChecks for the interpreter.
70 cl::opt
<bool> disableExpensiveChecks
{
71 "disable-expensive-checks",
72 cl::desc("Disables potentially expensive checks in the transform "
73 "interpreter, providing more speed at the expense of "
74 "potential memory problems and silent corruptions"),
/// When set, the combined transform module (script plus merged libraries) is
/// dumped before the transforms are applied.
77 cl::opt
<bool> dumpLibraryModule
{
78 "dump-library-module",
79 cl::desc("Prints the combined library module before the output"),
84 /// "Managed" static instance of the command-line options structure. This makes
85 /// them locally-scoped and explicitly initialized/deinitialized. While this is
86 /// not strictly necessary in the tool source file that is not being used as a
87 /// library (where the options would pollute the global list of options), it is
88 /// good practice to follow this.
89 static llvm::ManagedStatic
<MlirTransformOptCLOptions
> clOptions
;
91 /// Explicitly registers command-line options.
92 static void registerCLOptions() { *clOptions
; }
95 /// A wrapper class for source managers diagnostic. This provides both unique
96 /// ownership and virtual function-like overload for a pair of
97 /// inheritance-related classes that do not use virtual functions.
98 class DiagnosticHandlerWrapper
{
100 /// Kind of the diagnostic handler to use.
101 enum class Kind
{ EmitDiagnostics
, VerifyDiagnostics
};
103 /// Constructs the diagnostic handler of the specified kind of the given
104 /// source manager and context.
105 DiagnosticHandlerWrapper(Kind kind
, llvm::SourceMgr
&mgr
,
106 mlir::MLIRContext
*context
) {
107 if (kind
== Kind::EmitDiagnostics
)
108 handler
= new mlir::SourceMgrDiagnosticHandler(mgr
, context
);
110 handler
= new mlir::SourceMgrDiagnosticVerifierHandler(mgr
, context
);
113 /// This object is non-copyable but movable.
114 DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper
&) = delete;
115 DiagnosticHandlerWrapper(DiagnosticHandlerWrapper
&&other
) = default;
116 DiagnosticHandlerWrapper
&
117 operator=(const DiagnosticHandlerWrapper
&) = delete;
118 DiagnosticHandlerWrapper
&operator=(DiagnosticHandlerWrapper
&&) = default;
120 /// Verifies the captured "expected-*" diagnostics if required.
121 llvm::LogicalResult
verify() const {
123 handler
.dyn_cast
<mlir::SourceMgrDiagnosticVerifierHandler
*>()) {
124 return ptr
->verify();
126 return mlir::success();
129 /// Destructs the object of the same type as allocated.
130 ~DiagnosticHandlerWrapper() {
131 if (auto *ptr
= handler
.dyn_cast
<mlir::SourceMgrDiagnosticHandler
*>()) {
134 delete handler
.get
<mlir::SourceMgrDiagnosticVerifierHandler
*>();
139 /// Internal storage is a type-safe union.
140 llvm::PointerUnion
<mlir::SourceMgrDiagnosticHandler
*,
141 mlir::SourceMgrDiagnosticVerifierHandler
*>
145 /// MLIR has deeply rooted expectations that the LLVM source manager contains
146 /// exactly one buffer, until at least the lexer level. This class wraps
147 /// multiple LLVM source managers each managing a buffer to match MLIR's
148 /// expectations while still providing a centralized handling mechanism.
149 class TransformSourceMgr
{
151 /// Constructs the source manager indicating whether diagnostic messages will
152 /// be verified later on.
153 explicit TransformSourceMgr(bool verifyDiagnostics
)
154 : verifyDiagnostics(verifyDiagnostics
) {}
156 /// Deconstructs the source manager. Note that `checkResults` must have been
157 /// called on this instance before deconstructing it.
158 ~TransformSourceMgr() {
159 assert(resultChecked
&& "must check the result of diagnostic handlers by "
160 "running TransformSourceMgr::checkResult");
163 /// Parses the given buffer and creates the top-level operation of the kind
164 /// specified as template argument in the given context. Additional parsing
165 /// options may be provided.
166 template <typename OpTy
= mlir::Operation
*>
167 mlir::OwningOpRef
<OpTy
> parseBuffer(std::unique_ptr
<MemoryBuffer
> buffer
,
168 mlir::MLIRContext
&context
,
169 const mlir::ParserConfig
&config
) {
170 // Create a single-buffer LLVM source manager. Note that `unique_ptr` allows
171 // the code below to capture a reference to the source manager in such a way
172 // that it is not invalidated when the vector contents is eventually
174 llvm::SourceMgr
&mgr
=
175 *sourceMgrs
.emplace_back(std::make_unique
<llvm::SourceMgr
>());
176 mgr
.AddNewSourceBuffer(std::move(buffer
), llvm::SMLoc());
178 // Choose the type of diagnostic handler depending on whether diagnostic
179 // verification needs to happen and store it.
180 if (verifyDiagnostics
) {
181 diagHandlers
.emplace_back(
182 DiagnosticHandlerWrapper::Kind::VerifyDiagnostics
, mgr
, &context
);
184 diagHandlers
.emplace_back(DiagnosticHandlerWrapper::Kind::EmitDiagnostics
,
188 // Defer to MLIR's parser.
189 return mlir::parseSourceFile
<OpTy
>(mgr
, config
);
192 /// If diagnostic message verification has been requested upon construction of
193 /// this source manager, performs the verification, reports errors and returns
194 /// the result of the verification. Otherwise passes through the given value.
195 llvm::LogicalResult
checkResult(llvm::LogicalResult result
) {
196 resultChecked
= true;
197 if (!verifyDiagnostics
)
200 return mlir::failure(llvm::any_of(diagHandlers
, [](const auto &handler
) {
201 return mlir::failed(handler
.verify());
206 /// Indicates whether diagnostic message verification is requested.
207 const bool verifyDiagnostics
;
209 /// Indicates that diagnostic message verification has taken place, and the
210 /// deconstruction is therefore safe.
211 bool resultChecked
= false;
213 /// Storage for per-buffer source managers and diagnostic handlers. These are
214 /// wrapped into unique pointers in order to make it safe to capture
215 /// references to these objects: if the vector is reallocated, the unique
216 /// pointer objects are moved by the pointer addresses won't change. Also, for
217 /// handlers, this allows to store the pointer to the base class.
218 SmallVector
<std::unique_ptr
<llvm::SourceMgr
>> sourceMgrs
;
219 SmallVector
<DiagnosticHandlerWrapper
> diagHandlers
;
223 /// Trivial wrapper around `applyTransforms` that doesn't support extra mapping
224 /// and doesn't enforce the entry point transform ops being top-level.
225 static llvm::LogicalResult
226 applyTransforms(mlir::Operation
*payloadRoot
,
227 mlir::transform::TransformOpInterface transformRoot
,
228 const mlir::transform::TransformOptions
&options
) {
229 return applyTransforms(payloadRoot
, transformRoot
, {}, options
,
230 /*enforceToplevelTransformOp=*/false);
233 /// Applies transforms indicated in the transform dialect script to the input
234 /// buffer. The transform script may be embedded in the input buffer or as a
235 /// separate buffer. The transform script may have external symbols, the
236 /// definitions of which must be provided in transform library buffers. If the
237 /// application is successful, prints the transformed input buffer into the
238 /// given output stream. Additional configuration options are derived from
239 /// command-line options.
240 static llvm::LogicalResult
processPayloadBuffer(
241 raw_ostream
&os
, std::unique_ptr
<MemoryBuffer
> inputBuffer
,
242 std::unique_ptr
<llvm::MemoryBuffer
> transformBuffer
,
243 MutableArrayRef
<std::unique_ptr
<MemoryBuffer
>> transformLibraries
,
244 mlir::DialectRegistry
®istry
) {
246 // Initialize the MLIR context, and various configurations.
247 mlir::MLIRContext
context(registry
, mlir::MLIRContext::Threading::DISABLED
);
248 context
.allowUnregisteredDialects(clOptions
->allowUnregisteredDialects
);
249 mlir::ParserConfig
config(&context
);
250 TransformSourceMgr
sourceMgr(
251 /*verifyDiagnostics=*/clOptions
->verifyDiagnostics
);
253 // Parse the input buffer that will be used as transform payload.
254 mlir::OwningOpRef
<mlir::Operation
*> payloadRoot
=
255 sourceMgr
.parseBuffer(std::move(inputBuffer
), context
, config
);
257 return sourceMgr
.checkResult(mlir::failure());
259 // Identify the module containing the transform script entry point. This may
260 // be the same module as the input or a separate module. In the former case,
261 // make a copy of the module so it can be modified freely. Modification may
262 // happen in the script itself (at which point it could be rewriting itself
263 // during interpretation, leading to tricky memory errors) or by embedding
264 // library modules in the script.
265 mlir::OwningOpRef
<mlir::ModuleOp
> transformRoot
;
266 if (transformBuffer
) {
267 transformRoot
= sourceMgr
.parseBuffer
<mlir::ModuleOp
>(
268 std::move(transformBuffer
), context
, config
);
270 return sourceMgr
.checkResult(mlir::failure());
272 transformRoot
= cast
<mlir::ModuleOp
>(payloadRoot
->clone());
275 // Parse and merge the libraries into the main transform module.
276 for (auto &&transformLibrary
: transformLibraries
) {
277 mlir::OwningOpRef
<mlir::ModuleOp
> libraryModule
=
278 sourceMgr
.parseBuffer
<mlir::ModuleOp
>(std::move(transformLibrary
),
281 if (!libraryModule
||
282 mlir::failed(mlir::transform::detail::mergeSymbolsInto(
283 *transformRoot
, std::move(libraryModule
))))
284 return sourceMgr
.checkResult(mlir::failure());
287 // If requested, dump the combined transform module.
288 if (clOptions
->dumpLibraryModule
)
289 transformRoot
->dump();
291 // Find the entry point symbol. Even if it had originally been in the payload
292 // module, it was cloned into the transform module so only look there.
293 mlir::transform::TransformOpInterface entryPoint
=
294 mlir::transform::detail::findTransformEntryPoint(
295 *transformRoot
, mlir::ModuleOp(), clOptions
->transformEntryPoint
);
297 return sourceMgr
.checkResult(mlir::failure());
299 // Apply the requested transformations.
300 mlir::transform::TransformOptions transformOptions
;
301 transformOptions
.enableExpensiveChecks(!clOptions
->disableExpensiveChecks
);
302 if (mlir::failed(applyTransforms(*payloadRoot
, entryPoint
, transformOptions
)))
303 return sourceMgr
.checkResult(mlir::failure());
305 // Print the transformed result and check the captured diagnostics if
307 payloadRoot
->print(os
);
308 return sourceMgr
.checkResult(mlir::success());
311 /// Tool entry point.
312 static llvm::LogicalResult
runMain(int argc
, char **argv
) {
313 // Register all upstream dialects and extensions. Specific uses are advised
314 // not to register all dialects indiscriminately but rather hand-pick what is
315 // necessary for their use case.
316 mlir::DialectRegistry registry
;
317 mlir::registerAllDialects(registry
);
318 mlir::registerAllExtensions(registry
);
319 mlir::registerAllPasses();
321 // Explicitly register the transform dialect. This is not strictly necessary
322 // since it has been already registered as part of the upstream dialect list,
323 // but useful for example purposes for cases when dialects to register are
324 // hand-picked. The transform dialect must be registered.
325 registry
.insert
<mlir::transform::TransformDialect
>();
327 // Register various command-line options. Note that the LLVM initializer
328 // object is a RAII that ensures correct deconstruction of command-line option
329 // objects inside ManagedStatic.
330 llvm::InitLLVM
y(argc
, argv
);
331 mlir::registerAsmPrinterCLOptions();
332 mlir::registerMLIRContextCLOptions();
334 llvm::cl::ParseCommandLineOptions(argc
, argv
,
335 "Minimal Transform dialect driver\n");
337 // Try opening the main input file.
338 std::string errorMessage
;
339 std::unique_ptr
<llvm::MemoryBuffer
> payloadFile
=
340 mlir::openInputFile(clOptions
->payloadFilename
, &errorMessage
);
342 llvm::errs() << errorMessage
<< "\n";
343 return mlir::failure();
346 // Try opening the output file.
347 std::unique_ptr
<llvm::ToolOutputFile
> outputFile
=
348 mlir::openOutputFile(clOptions
->outputFilename
, &errorMessage
);
350 llvm::errs() << errorMessage
<< "\n";
351 return mlir::failure();
354 // Try opening the main transform file if provided.
355 std::unique_ptr
<llvm::MemoryBuffer
> transformRootFile
;
356 if (!clOptions
->transformMainFilename
.empty()) {
357 if (clOptions
->transformMainFilename
== clOptions
->payloadFilename
) {
358 llvm::errs() << "warning: " << clOptions
->payloadFilename
359 << " is provided as both payload and transform file\n";
362 mlir::openInputFile(clOptions
->transformMainFilename
, &errorMessage
);
363 if (!transformRootFile
) {
364 llvm::errs() << errorMessage
<< "\n";
365 return mlir::failure();
370 // Try opening transform library files if provided.
371 SmallVector
<std::unique_ptr
<llvm::MemoryBuffer
>> transformLibraries
;
372 transformLibraries
.reserve(clOptions
->transformLibraryFilenames
.size());
373 for (llvm::StringRef filename
: clOptions
->transformLibraryFilenames
) {
374 transformLibraries
.emplace_back(
375 mlir::openInputFile(filename
, &errorMessage
));
376 if (!transformLibraries
.back()) {
377 llvm::errs() << errorMessage
<< "\n";
378 return mlir::failure();
382 return processPayloadBuffer(outputFile
->os(), std::move(payloadFile
),
383 std::move(transformRootFile
), transformLibraries
,
387 int main(int argc
, char **argv
) {
388 return mlir::asMainReturnCode(runMain(argc
, argv
));