1 //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file declares a byte-code and interpreter for pattern rewrites in MLIR.
10 // The byte-code is constructed from the PDL Interpreter dialect.
12 //===----------------------------------------------------------------------===//
14 #ifndef MLIR_REWRITE_BYTECODE_H_
15 #define MLIR_REWRITE_BYTECODE_H_
17 #include "mlir/IR/PatternMatch.h"
19 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
namespace pdl_interp {
/// Forward declaration of `RecordMatchOp`, which this header names in the
/// parameter list of `PDLByteCodePattern::create`. A declaration is enough
/// here because the type only appears in a function declaration.
class RecordMatchOp;
} // namespace pdl_interp
29 /// Use generic bytecode types. ByteCodeField refers to the actual bytecode
30 /// entries. ByteCodeAddr refers to size of indices into the bytecode.
31 using ByteCodeField
= uint16_t;
32 using ByteCodeAddr
= uint32_t;
33 using OwningOpRange
= llvm::OwningArrayRef
<Operation
*>;
35 //===----------------------------------------------------------------------===//
37 //===----------------------------------------------------------------------===//
39 /// All of the data pertaining to a specific pattern within the bytecode.
40 class PDLByteCodePattern
: public Pattern
{
42 static PDLByteCodePattern
create(pdl_interp::RecordMatchOp matchOp
,
43 PDLPatternConfigSet
*configSet
,
44 ByteCodeAddr rewriterAddr
);
46 /// Return the bytecode address of the rewriter for this pattern.
47 ByteCodeAddr
getRewriterAddr() const { return rewriterAddr
; }
49 /// Return the configuration set for this pattern, or null if there is none.
50 PDLPatternConfigSet
*getConfigSet() const { return configSet
; }
53 template <typename
... Args
>
54 PDLByteCodePattern(ByteCodeAddr rewriterAddr
, PDLPatternConfigSet
*configSet
,
55 Args
&&...patternArgs
)
56 : Pattern(std::forward
<Args
>(patternArgs
)...), rewriterAddr(rewriterAddr
),
57 configSet(configSet
) {}
59 /// The address of the rewriter for this pattern.
60 ByteCodeAddr rewriterAddr
;
62 /// The optional config set for this pattern.
63 PDLPatternConfigSet
*configSet
;
66 //===----------------------------------------------------------------------===//
67 // PDLByteCodeMutableState
68 //===----------------------------------------------------------------------===//
/// This class contains the mutable state of a bytecode instance. This allows
/// for a bytecode instance to be cached and reused across various different
/// threads/drivers.
73 class PDLByteCodeMutableState
{
75 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
76 /// to the position of the pattern within the range returned by
77 /// `PDLByteCode::getPatterns`.
78 void updatePatternBenefit(unsigned patternIndex
, PatternBenefit benefit
);
/// Clean up any state allocated during a match/rewrite once it has completed.
/// Call this regardless of whether the match+rewrite succeeded.
void cleanupAfterMatchAndRewrite();
86 /// Allow access to data fields.
87 friend class PDLByteCode
;
89 /// The mutable block of memory used during the matching and rewriting phases
91 std::vector
<const void *> memory
;
93 /// A mutable block of memory used during the matching and rewriting phase of
94 /// the bytecode to store ranges of operations. These are always stored by
95 /// owning references, because at no point in the execution of the byte code
96 /// we get an indexed range (view) of operations.
97 std::vector
<OwningOpRange
> opRangeMemory
;
99 /// A mutable block of memory used during the matching and rewriting phase of
100 /// the bytecode to store ranges of types.
101 std::vector
<TypeRange
> typeRangeMemory
;
102 /// A set of type ranges that have been allocated by the byte code interpreter
103 /// to provide a guaranteed lifetime.
104 std::vector
<llvm::OwningArrayRef
<Type
>> allocatedTypeRangeMemory
;
106 /// A mutable block of memory used during the matching and rewriting phase of
107 /// the bytecode to store ranges of values.
108 std::vector
<ValueRange
> valueRangeMemory
;
109 /// A set of value ranges that have been allocated by the byte code
110 /// interpreter to provide a guaranteed lifetime.
111 std::vector
<llvm::OwningArrayRef
<Value
>> allocatedValueRangeMemory
;
113 /// The current index of ranges being iterated over for each level of nesting.
114 /// These are always maintained at 0 for the loops that are not active, so we
115 /// do not need to have a separate initialization phase for each loop.
116 std::vector
<unsigned> loopIndex
;
118 /// The up-to-date benefits of the patterns held by the bytecode. The order
119 /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
120 std::vector
<PatternBenefit
> currentPatternBenefits
;
123 //===----------------------------------------------------------------------===//
125 //===----------------------------------------------------------------------===//
127 /// The bytecode class is also the interpreter. Contains the bytecode itself,
128 /// the static info, addresses of the rewriter functions, the interpreter
129 /// memory buffer, and the execution context.
132 /// Each successful match returns a MatchResult, which contains information
133 /// necessary to execute the rewriter and indicates the originating pattern.
135 MatchResult(Location loc
, const PDLByteCodePattern
&pattern
,
136 PatternBenefit benefit
)
137 : location(loc
), pattern(&pattern
), benefit(benefit
) {}
138 MatchResult(const MatchResult
&) = delete;
139 MatchResult
&operator=(const MatchResult
&) = delete;
140 MatchResult(MatchResult
&&other
) = default;
141 MatchResult
&operator=(MatchResult
&&) = default;
143 /// The location of operations to be replaced.
145 /// Memory values defined in the matcher that are passed to the rewriter.
146 SmallVector
<const void *> values
;
147 /// Memory used for the range input values.
148 SmallVector
<TypeRange
, 0> typeRangeValues
;
149 SmallVector
<ValueRange
, 0> valueRangeValues
;
151 /// The originating pattern that was matched. This is always non-null, but
152 /// represented with a pointer to allow for assignment.
153 const PDLByteCodePattern
*pattern
;
154 /// The current benefit of the pattern that was matched.
155 PatternBenefit benefit
;
158 /// Create a ByteCode instance from the given module containing operations in
159 /// the PDL interpreter dialect.
160 PDLByteCode(ModuleOp module
,
161 SmallVector
<std::unique_ptr
<PDLPatternConfigSet
>> configs
,
162 const DenseMap
<Operation
*, PDLPatternConfigSet
*> &configMap
,
163 llvm::StringMap
<PDLConstraintFunction
> constraintFns
,
164 llvm::StringMap
<PDLRewriteFunction
> rewriteFns
);
166 /// Return the patterns held by the bytecode.
167 ArrayRef
<PDLByteCodePattern
> getPatterns() const { return patterns
; }
169 /// Initialize the given state such that it can be used to execute the current
171 void initializeMutableState(PDLByteCodeMutableState
&state
) const;
173 /// Run the pattern matcher on the given root operation, collecting the
174 /// matched patterns in `matches`.
175 void match(Operation
*op
, PatternRewriter
&rewriter
,
176 SmallVectorImpl
<MatchResult
> &matches
,
177 PDLByteCodeMutableState
&state
) const;
179 /// Run the rewriter of the given pattern that was previously matched in
180 /// `match`. Returns if a failure was encountered during the rewrite.
181 LogicalResult
rewrite(PatternRewriter
&rewriter
, const MatchResult
&match
,
182 PDLByteCodeMutableState
&state
) const;
185 /// Execute the given byte code starting at the provided instruction `inst`.
186 /// `matches` is an optional field provided when this function is executed in
187 /// a matching context.
188 void executeByteCode(const ByteCodeField
*inst
, PatternRewriter
&rewriter
,
189 PDLByteCodeMutableState
&state
,
190 SmallVectorImpl
<MatchResult
> *matches
) const;
192 /// The set of pattern configs referenced within the bytecode.
193 SmallVector
<std::unique_ptr
<PDLPatternConfigSet
>> configs
;
195 /// A vector containing pointers to uniqued data. The storage is intentionally
196 /// opaque such that we can store a wide range of data types. The types of
197 /// data stored here include:
198 /// * Attribute, OperationName, Type
199 std::vector
<const void *> uniquedData
;
201 /// A vector containing the generated bytecode for the matcher.
202 SmallVector
<ByteCodeField
, 64> matcherByteCode
;
204 /// A vector containing the generated bytecode for all of the rewriters.
205 SmallVector
<ByteCodeField
, 64> rewriterByteCode
;
207 /// The set of patterns contained within the bytecode.
208 SmallVector
<PDLByteCodePattern
, 32> patterns
;
210 /// A set of user defined functions invoked via PDL.
211 std::vector
<PDLConstraintFunction
> constraintFunctions
;
212 std::vector
<PDLRewriteFunction
> rewriteFunctions
;
214 /// The maximum memory index used by a value.
215 ByteCodeField maxValueMemoryIndex
= 0;
217 /// The maximum number of different types of ranges.
218 ByteCodeField maxOpRangeCount
= 0;
219 ByteCodeField maxTypeRangeCount
= 0;
220 ByteCodeField maxValueRangeCount
= 0;
222 /// The maximum number of nested loops.
223 ByteCodeField maxLoopLevel
= 0;
226 } // namespace detail
231 namespace mlir::detail
{
233 class PDLByteCodeMutableState
{
/// Stub with an empty body: nothing is cleaned up.
/// NOTE(review): presumably the fallback used when
/// MLIR_ENABLE_PDL_IN_PATTERNMATCH is off — the `#else` is not visible here.
void cleanupAfterMatchAndRewrite() {
}
236 void updatePatternBenefit(unsigned patternIndex
, PatternBenefit benefit
) {}
239 class PDLByteCodePattern
: public Pattern
{};
244 const PDLByteCodePattern
*pattern
= nullptr;
245 PatternBenefit benefit
;
248 void initializeMutableState(PDLByteCodeMutableState
&state
) const {}
249 void match(Operation
*op
, PatternRewriter
&rewriter
,
250 SmallVectorImpl
<MatchResult
> &matches
,
251 PDLByteCodeMutableState
&state
) const {}
252 LogicalResult
rewrite(PatternRewriter
&rewriter
, const MatchResult
&match
,
253 PDLByteCodeMutableState
&state
) const {
256 ArrayRef
<PDLByteCodePattern
> getPatterns() const { return {}; }
259 } // namespace mlir::detail
261 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
263 #endif // MLIR_REWRITE_BYTECODE_H_