[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / docs / Tutorials / QuickstartRewrites.md
blob604148bd9c600783ba1b2544c9ca3edc5bad8892
1 # Quickstart tutorial to adding MLIR graph rewrite
3 This document will present a quickstart to adding graph rewrites. We shall start
4 by defining an operation, showing multiple ways to define the rewrite using
5 patterns, as well as defining the rewrite using a graph walker (note: using
6 patterns and the rewrite engine is preferred, showing the walker is for
7 demonstration purposes).
9 See [MLIR specification](../LangRef.md) for more information about MLIR, the
10 structure of the IR, operations, etc. See
11 [Table-driven Operation Definition](../DefiningDialects/Operations.md) and
12 [Declarative Rewrite Rule](../DeclarativeRewrites.md) for the detailed explanation
13 of all available mechanisms for defining operations and rewrites in a
14 table-driven manner.
16 ## Adding operation
18 An operation in MLIR is specified using a definition in
19 [TableGen](https://llvm.org/docs/TableGen/index.html) file. TableGen is a
20 modeling tool to specify the ops and the C++ code to interact with these
21 operations are generated from. To define an operation one needs to specify:
23 *   The operation name. This name is a unique identifier of the operation within
24     MLIR. Most operations are within a dialect, so for example one could have
25     `tfl.add` to represent the add operation in the TensorFlow Lite dialect.
26     Instead of repeating the dialect in the op definition, a base class for the
27     op dialect is commonly created that prepends the dialect namespace given an
28     op name.
29 *   The traits of the operation. These allow you to specify traits of the
30     operation, such as whether it has side effects or whether it should be
31     verified that the operands and result types are the same. These are backed
32     by C++ traits that perform the verification.
33 *   The arguments of the operation. These are the input operands (values at
34     runtime produced by other ops) and attributes (compile time known constant
35     values that affect the behavior of the op) that are the inputs of/define the
36     behavior of the operation. The input operands may be named, the attributes
37     must be named.
38 *   The result(s) of the operation. These may again named or not.
39 *   Documentation of the operation. This includes a one-line summary as well as
40     a longer human-readable description of the operation.
41 *   Dialect specific information. Additional information could be added to the
42     operation definition that are only used by dialect specific drivers. These
43     are ignored by the main op and doc generators, but could be used in, say,
44     the translation from a dialect to another representation.
46 ```tablegen
47 def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
48                             [NoMemoryEffect, SameValueType]>,
49                      Results<(outs Tensor)> {
50   let arguments = (ins
51     F32Tensor:$x,
52     // Slope of the activation function at x < 0.
53     F32Attr:$alpha
54   );
56   let summary = "Leaky ReLU operator";
57   let description = [{
58     Element-wise Leaky ReLU operator
59       x -> x >= 0 ? x : (alpha * x)
60   }];
62   // TFLite specific attribute that is used when generating the output
63   // flatbuffer.
64   let hasOptions = 1;
66 ```
68 Note in the above the result types and inputs are specified in different ways,
69 one by way of trait and the other by way of let. It is possible to specify both
70 in either way.
72 <!-- TODO: Define a style convention. -->
74 Operations can also have custom parser, printer, builder, verifier, constant
75 folder, or canonicalizer. These require specifying additional C++ methods to
76 invoke for additional functionality. For example, if an operation is marked to
77 have a folder, the constant folder also needs to be added, e.g.,:
79 ```c++
80 OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
81   if (unable_to_fold)
82     return {};
83   ....
84   return val;
86 ```
88 ## Adding patterns
90 There are multiple forms of graph rewrite that can be performed in MLIR. One of
91 the most common is DAG tile to DAG tile rewrite. Patterns provide a concise way
92 to express this transformation as a pair of source pattern to match and
93 resultant pattern. There are both the C++ classes to represent this
94 transformation, as well as the patterns in TableGen from which these can be
95 generated.
97 ### TableGen patterns
99 Let us continue with LeakyRelu. To map from TensorFlow's `LeakyRelu` to
100 TensorFlow Lite's `LeakyRelu`:
102 ```tablegen
103 def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>
106 The pattern is specified by instantiating a `Pat` with a source and result DAG.
107 The arguments in the source pattern is captured and can be used in the result
108 pattern. This is a simple pattern as we have a 1:1 mapping and the attribute
109 does not need to be transformed (e.g., both have a floating point attribute for
110 alpha). The names of the attributes specified in the pattern is for
111 matching/referencing and need not match the original attribute name in the op
112 definition but the order of arguments of the dags do need to match.
114 To specify a pattern, both the source and resultant ops need to be defined using
115 TableGen.
117 If this were a more advance pattern that the current framework could not express
118 as destination then one could use a general native code fallback method. This
119 consists of defining a pattern as well as adding a C++ function to perform the
120 replacement:
122 ```tablegen
123 def createTFLLeakyRelu : NativeCodeCall<
124     "createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;
126 def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a),
127           (createTFLLeakyRelu $old_value, $arg, $a)>;
130 ```c++
131 static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
132                                 Value operand, Attribute attr) {
133   return rewriter.create<mlir::TFL::LeakyReluOp>(
134       op->getLoc(), operands[0].getType(), /*arg=*/operands[0],
135       /*alpha=*/attrs[0].cast<FloatAttr>());
139 This allows for arbitrarily complex builders. Input pattern side one can express
140 multi-op patterns with constraints on input operands and attributes. But input
141 patterns cannot yet express constraints across multiple operands/attributes.
143 ### Register the pattern
145 The file containing the patterns need to be processed using `mlir-tblgen`
146 `-gen-rewriters` during compilation time. It can be invoked with the following
147 configuration in CMake:
149 ```cmake
150 set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
151 mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
152 add_public_tablegen_target(<name-of-the-cmake-target>)
155 Then you can `#include` the generated file in any C++ implementation file you
156 like. (You will also need to make sure the library depends on the CMake target
157 defined in the above.) The generated file will have a `populateWithGenerated(
158 RewritePatternSet &patterns)` function that you can
159 use to collect all the generated patterns inside `patterns` and then use
160 `patterns` in any pass you would like.
162 ### Simple C++ `matchAndRewrite` style specifications
164 Many simple rewrites can be expressed with a `matchAndRewrite` style  of
165 pattern, e.g. when converting a multiply by a power of two into a shift.  For
166 these cases, the you can define the pattern as a simple function:
168 ```c++
169 static LogicalResult
170 convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
171   rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
172       op, op->getResult(0).getType(), op->getOperand(0),
173       /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
174   return success();
177 void populateRewrites(RewritePatternSet &patternSet) {
178   // Add it to a pattern set.
179   patternSet.add(convertTFLeakyRelu);
183 ODS provides a simple way to define a function-style canonicalization for your
184 operation.  In the TableGen definition of the op, specify
185 `let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method in
186 your .cpp file:
188 ```c++
189 // Example from the CIRCT project which has a variadic integer multiply.
190 LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
191   auto inputs = op.inputs();
192   APInt value;
194   // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
195   if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
196       value.isPowerOf2()) {
197     auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),
198                                                   value.exactLogBase2());
199     auto shlOp =
200         rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
201     rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(),
202                                        ArrayRef<Value>(shlOp));
203     return success();
204   }
206   return failure();
210 However, you may want the full generality of canonicalization patterns, for that
211 you can specify an arbitrary list of `RewritePattern`s.
213 ### Fully general C++ `RewritePattern` specifications
215 In case ODS patterns and `matchAndRewrite`-style functions are not sufficient
216 you can also specify rewrites as a general set of `RewritePattern`s:
218 ```c++
219 /// Multi-step rewrite using "match" and "rewrite". This allows for separating
220 /// the concerns of matching and rewriting.
221 struct ConvertTFLeakyRelu : public RewritePattern {
222   ConvertTFLeakyRelu(MLIRContext *context)
223       : RewritePattern("tf.LeakyRelu", 1, context) {}
225   LogicalResult match(Operation *op) const override {
226     return success();
227   }
229   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
230     rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
231         op, op->getResult(0).getType(), op->getOperand(0),
232         /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
233   }
236 /// Single-step rewrite with "matchAndRewrite". This allows for performing the
237 /// rewrite immediately upon a successful match.
238 struct ConvertTFLeakyRelu : public RewritePattern {
239   ConvertTFLeakyRelu(MLIRContext *context)
240       : RewritePattern("tf.LeakyRelu", 1, context) {}
242   LogicalResult matchAndRewrite(Operation *op,
243                                 PatternRewriter &rewriter) const override {
244     rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
245         op, op->getResult(0).getType(), op->getOperand(0),
246         /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
247     return success();
248   }
252 In the C++ rewrite the static benefit of the rewrite pattern is specified at
253 construction. While in the pattern generator a simple heuristic is currently
254 employed based around the number of ops matched and replaced.
256 The above rule did not capture the matching operands/attributes, but in general
257 the `match` function in a multi-step rewrite may populate and return a
258 `PatternState` (or class derived from one) to pass information extracted during
259 matching to the rewrite. A single-step rewrite with the `matchAndRewrite`
260 function has the benefit of being able to directly use any values created when
261 matching; removing the need for `PatternState`.
263 ## Testing
265 MLIR uses [lit](https://llvm.org/docs/CommandGuide/lit.html) (LLVM Integrated
266 Testing) tool for performing testing. Testing is performed by way of creating
267 the input IR file, running a transformation and then verifying the output IR.
268 C++ unit tests are the exception, with the IR transformation serving as the core
269 testing mechanism. This results in fewer binaries that need to be built (and
270 linked) and forces to focus on the representation as an important piece.
272 For the legalization transform above we would have a test (probably as part of
273 the legalization pass test in TensorFlow Lite) such as:
275 ```mlir
276 // RUN: mlir-opt -tfl-legalize-tf %s | FileCheck %s
278 func.func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
279   %2 = "tf.LeakyRelu"(%arg0) {alpha: 0.1} : (tensor<1xf32>) -> tensor<1xf32>
280   return %2: tensor<1xf32>
282 // CHECK-LABEL: LeakyRelu
283 // CHECK:  %0 = "tfl.leaky_relu"(%arg0) {alpha: 1.000000e-01} : (tensor<1xf32>) -> tensor<1xf32>
287 The RUN command at the top results in running the `mlir-opt` binary (which is
288 compiler writer tool to exercise different registered passes) to invoke the
289 optimization pass this transform was added as part of on the current file and to
290 verify its output using `FileCheck`. `FileCheck` is textual output verifier. In
291 particular it uses the CHECK expressions to verify the given output is produced.
293 There can be multiple RUN commands with different corresponding CHECK prefixes.
294 And in addition multiple independent tests separated by `// -----` and
295 `mlir-opt` invoked with `-split-input-file` flag. This is especially useful for
296 error testing.
298 This results in very simple, directed testing without need to work around
299 constant propagation or other, unrelated, optimization passes.
301 ## Adding optimization pass
303 Optimization passes that do not fit/are difficult to specify in the above
304 structure can be specified as general iterations across modules/functions. See
305 [Writing a Pass](../PassManagement.md) for a general overview and introduction to
306 optimization passes in MLIR.