//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements MLIR to byte-code generation and the interpreter.
//
//===----------------------------------------------------------------------===//
#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstring>
#include <limits>
#include <memory>
#include <numeric>
#include <optional>

#define DEBUG_TYPE "pdl-bytecode"

using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
// PDLByteCodePattern
//===----------------------------------------------------------------------===//
37 PDLByteCodePattern
PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp
,
38 PDLPatternConfigSet
*configSet
,
39 ByteCodeAddr rewriterAddr
) {
40 PatternBenefit benefit
= matchOp
.getBenefit();
41 MLIRContext
*ctx
= matchOp
.getContext();
43 // Collect the set of generated operations.
44 SmallVector
<StringRef
, 8> generatedOps
;
45 if (ArrayAttr generatedOpsAttr
= matchOp
.getGeneratedOpsAttr())
47 llvm::to_vector
<8>(generatedOpsAttr
.getAsValueRange
<StringAttr
>());
49 // Check to see if this is pattern matches a specific operation type.
50 if (std::optional
<StringRef
> rootKind
= matchOp
.getRootKind())
51 return PDLByteCodePattern(rewriterAddr
, configSet
, *rootKind
, benefit
, ctx
,
53 return PDLByteCodePattern(rewriterAddr
, configSet
, MatchAnyOpTypeTag(),
54 benefit
, ctx
, generatedOps
);
//===----------------------------------------------------------------------===//
// PDLByteCodeMutableState
//===----------------------------------------------------------------------===//
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
64 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex
,
65 PatternBenefit benefit
) {
66 currentPatternBenefits
[patternIndex
] = benefit
;
69 /// Cleanup any allocated state after a full match/rewrite has been completed.
70 /// This method should be called irregardless of whether the match+rewrite was a
72 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
73 allocatedTypeRangeMemory
.clear();
74 allocatedValueRangeMemory
.clear();
//===----------------------------------------------------------------------===//
// Bytecode OpCodes
//===----------------------------------------------------------------------===//
82 enum OpCode
: ByteCodeField
{
83 /// Apply an externally registered constraint.
85 /// Apply an externally registered rewrite.
87 /// Check if two generic values are equal.
89 /// Check if two ranges are equal.
91 /// Unconditional branch.
93 /// Compare the operand count of an operation with a constant.
95 /// Compare the name of an operation with a constant.
97 /// Compare the result count of an operation with a constant.
99 /// Compare a range of types to a constant range of types.
101 /// Continue to the next iteration of a loop.
103 /// Create a type range from a list of constant types.
104 CreateConstantTypeRange
,
105 /// Create an operation.
107 /// Create a type range from a list of dynamic types.
108 CreateDynamicTypeRange
,
109 /// Create a value range.
110 CreateDynamicValueRange
,
111 /// Erase an operation.
113 /// Extract the op from a range at the specified index.
115 /// Extract the type from a range at the specified index.
117 /// Extract the value from a range at the specified index.
119 /// Terminate a matcher or rewrite sequence.
121 /// Iterate over a range of values.
123 /// Get a specific attribute of an operation.
125 /// Get the type of an attribute.
127 /// Get the defining operation of a value.
129 /// Get a specific operand of an operation.
135 /// Get a specific operand group of an operation.
137 /// Get a specific result of an operation.
143 /// Get a specific result group of an operation.
145 /// Get the users of a value or a range of values.
147 /// Get the type of a value.
149 /// Get the types of a value range.
151 /// Check if a generic value is not null.
153 /// Record a successful pattern match.
155 /// Replace an operation.
157 /// Compare an attribute with a set of constants.
159 /// Compare the operand count of an operation with a set of constants.
161 /// Compare the name of an operation with a set of constants.
163 /// Compare the result count of an operation with a set of constants.
165 /// Compare a type with a set of constants.
167 /// Compare a range of types with a set of constants.
172 /// A marker used to indicate if an operation should infer types.
173 static constexpr ByteCodeField kInferTypesMarker
=
174 std::numeric_limits
<ByteCodeField
>::max();
//===----------------------------------------------------------------------===//
// ByteCode Generation
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Generator
namespace {
struct ByteCodeLiveRange;
struct ByteCodeWriter;

/// Check if the given class `T` can be converted to an opaque pointer, i.e. it
/// provides a `getAsOpaquePointer()` member. Intended for use with a detection
/// idiom such as `llvm::is_detected`; the trailing `Args` pack is unused.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
} // namespace
191 /// This class represents the main generator for the pattern bytecode.
194 Generator(MLIRContext
*ctx
, std::vector
<const void *> &uniquedData
,
195 SmallVectorImpl
<ByteCodeField
> &matcherByteCode
,
196 SmallVectorImpl
<ByteCodeField
> &rewriterByteCode
,
197 SmallVectorImpl
<PDLByteCodePattern
> &patterns
,
198 ByteCodeField
&maxValueMemoryIndex
,
199 ByteCodeField
&maxOpRangeMemoryIndex
,
200 ByteCodeField
&maxTypeRangeMemoryIndex
,
201 ByteCodeField
&maxValueRangeMemoryIndex
,
202 ByteCodeField
&maxLoopLevel
,
203 llvm::StringMap
<PDLConstraintFunction
> &constraintFns
,
204 llvm::StringMap
<PDLRewriteFunction
> &rewriteFns
,
205 const DenseMap
<Operation
*, PDLPatternConfigSet
*> &configMap
)
206 : ctx(ctx
), uniquedData(uniquedData
), matcherByteCode(matcherByteCode
),
207 rewriterByteCode(rewriterByteCode
), patterns(patterns
),
208 maxValueMemoryIndex(maxValueMemoryIndex
),
209 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex
),
210 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex
),
211 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex
),
212 maxLoopLevel(maxLoopLevel
), configMap(configMap
) {
213 for (const auto &it
: llvm::enumerate(constraintFns
))
214 constraintToMemIndex
.try_emplace(it
.value().first(), it
.index());
215 for (const auto &it
: llvm::enumerate(rewriteFns
))
216 externalRewriterToMemIndex
.try_emplace(it
.value().first(), it
.index());
219 /// Generate the bytecode for the given PDL interpreter module.
220 void generate(ModuleOp module
);
222 /// Return the memory index to use for the given value.
223 ByteCodeField
&getMemIndex(Value value
) {
224 assert(valueToMemIndex
.count(value
) &&
225 "expected memory index to be assigned");
226 return valueToMemIndex
[value
];
229 /// Return the range memory index used to store the given range value.
230 ByteCodeField
&getRangeStorageIndex(Value value
) {
231 assert(valueToRangeIndex
.count(value
) &&
232 "expected range index to be assigned");
233 return valueToRangeIndex
[value
];
236 /// Return an index to use when referring to the given data that is uniqued in
237 /// the MLIR context.
238 template <typename T
>
239 std::enable_if_t
<!std::is_convertible
<T
, Value
>::value
, ByteCodeField
&>
241 const void *opaqueVal
= val
.getAsOpaquePointer();
243 // Get or insert a reference to this value.
244 auto it
= uniquedDataToMemIndex
.try_emplace(
245 opaqueVal
, maxValueMemoryIndex
+ uniquedData
.size());
247 uniquedData
.push_back(opaqueVal
);
248 return it
.first
->second
;
252 /// Allocate memory indices for the results of operations within the matcher
254 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc
,
255 ModuleOp rewriterModule
);
257 /// Generate the bytecode for the given operation.
258 void generate(Region
*region
, ByteCodeWriter
&writer
);
259 void generate(Operation
*op
, ByteCodeWriter
&writer
);
260 void generate(pdl_interp::ApplyConstraintOp op
, ByteCodeWriter
&writer
);
261 void generate(pdl_interp::ApplyRewriteOp op
, ByteCodeWriter
&writer
);
262 void generate(pdl_interp::AreEqualOp op
, ByteCodeWriter
&writer
);
263 void generate(pdl_interp::BranchOp op
, ByteCodeWriter
&writer
);
264 void generate(pdl_interp::CheckAttributeOp op
, ByteCodeWriter
&writer
);
265 void generate(pdl_interp::CheckOperandCountOp op
, ByteCodeWriter
&writer
);
266 void generate(pdl_interp::CheckOperationNameOp op
, ByteCodeWriter
&writer
);
267 void generate(pdl_interp::CheckResultCountOp op
, ByteCodeWriter
&writer
);
268 void generate(pdl_interp::CheckTypeOp op
, ByteCodeWriter
&writer
);
269 void generate(pdl_interp::CheckTypesOp op
, ByteCodeWriter
&writer
);
270 void generate(pdl_interp::ContinueOp op
, ByteCodeWriter
&writer
);
271 void generate(pdl_interp::CreateAttributeOp op
, ByteCodeWriter
&writer
);
272 void generate(pdl_interp::CreateOperationOp op
, ByteCodeWriter
&writer
);
273 void generate(pdl_interp::CreateRangeOp op
, ByteCodeWriter
&writer
);
274 void generate(pdl_interp::CreateTypeOp op
, ByteCodeWriter
&writer
);
275 void generate(pdl_interp::CreateTypesOp op
, ByteCodeWriter
&writer
);
276 void generate(pdl_interp::EraseOp op
, ByteCodeWriter
&writer
);
277 void generate(pdl_interp::ExtractOp op
, ByteCodeWriter
&writer
);
278 void generate(pdl_interp::FinalizeOp op
, ByteCodeWriter
&writer
);
279 void generate(pdl_interp::ForEachOp op
, ByteCodeWriter
&writer
);
280 void generate(pdl_interp::GetAttributeOp op
, ByteCodeWriter
&writer
);
281 void generate(pdl_interp::GetAttributeTypeOp op
, ByteCodeWriter
&writer
);
282 void generate(pdl_interp::GetDefiningOpOp op
, ByteCodeWriter
&writer
);
283 void generate(pdl_interp::GetOperandOp op
, ByteCodeWriter
&writer
);
284 void generate(pdl_interp::GetOperandsOp op
, ByteCodeWriter
&writer
);
285 void generate(pdl_interp::GetResultOp op
, ByteCodeWriter
&writer
);
286 void generate(pdl_interp::GetResultsOp op
, ByteCodeWriter
&writer
);
287 void generate(pdl_interp::GetUsersOp op
, ByteCodeWriter
&writer
);
288 void generate(pdl_interp::GetValueTypeOp op
, ByteCodeWriter
&writer
);
289 void generate(pdl_interp::IsNotNullOp op
, ByteCodeWriter
&writer
);
290 void generate(pdl_interp::RecordMatchOp op
, ByteCodeWriter
&writer
);
291 void generate(pdl_interp::ReplaceOp op
, ByteCodeWriter
&writer
);
292 void generate(pdl_interp::SwitchAttributeOp op
, ByteCodeWriter
&writer
);
293 void generate(pdl_interp::SwitchTypeOp op
, ByteCodeWriter
&writer
);
294 void generate(pdl_interp::SwitchTypesOp op
, ByteCodeWriter
&writer
);
295 void generate(pdl_interp::SwitchOperandCountOp op
, ByteCodeWriter
&writer
);
296 void generate(pdl_interp::SwitchOperationNameOp op
, ByteCodeWriter
&writer
);
297 void generate(pdl_interp::SwitchResultCountOp op
, ByteCodeWriter
&writer
);
299 /// Mapping from value to its corresponding memory index.
300 DenseMap
<Value
, ByteCodeField
> valueToMemIndex
;
302 /// Mapping from a range value to its corresponding range storage index.
303 DenseMap
<Value
, ByteCodeField
> valueToRangeIndex
;
305 /// Mapping from the name of an externally registered rewrite to its index in
306 /// the bytecode registry.
307 llvm::StringMap
<ByteCodeField
> externalRewriterToMemIndex
;
309 /// Mapping from the name of an externally registered constraint to its index
310 /// in the bytecode registry.
311 llvm::StringMap
<ByteCodeField
> constraintToMemIndex
;
313 /// Mapping from rewriter function name to the bytecode address of the
314 /// rewriter function in byte.
315 llvm::StringMap
<ByteCodeAddr
> rewriterToAddr
;
317 /// Mapping from a uniqued storage object to its memory index within
319 DenseMap
<const void *, ByteCodeField
> uniquedDataToMemIndex
;
321 /// The current level of the foreach loop.
322 ByteCodeField curLoopLevel
= 0;
324 /// The current MLIR context.
327 /// Mapping from block to its address.
328 DenseMap
<Block
*, ByteCodeAddr
> blockToAddr
;
330 /// Data of the ByteCode class to be populated.
331 std::vector
<const void *> &uniquedData
;
332 SmallVectorImpl
<ByteCodeField
> &matcherByteCode
;
333 SmallVectorImpl
<ByteCodeField
> &rewriterByteCode
;
334 SmallVectorImpl
<PDLByteCodePattern
> &patterns
;
335 ByteCodeField
&maxValueMemoryIndex
;
336 ByteCodeField
&maxOpRangeMemoryIndex
;
337 ByteCodeField
&maxTypeRangeMemoryIndex
;
338 ByteCodeField
&maxValueRangeMemoryIndex
;
339 ByteCodeField
&maxLoopLevel
;
341 /// A map of pattern configurations.
342 const DenseMap
<Operation
*, PDLPatternConfigSet
*> &configMap
;
345 /// This class provides utilities for writing a bytecode stream.
346 struct ByteCodeWriter
{
347 ByteCodeWriter(SmallVectorImpl
<ByteCodeField
> &bytecode
, Generator
&generator
)
348 : bytecode(bytecode
), generator(generator
) {}
350 /// Append a field to the bytecode.
351 void append(ByteCodeField field
) { bytecode
.push_back(field
); }
352 void append(OpCode opCode
) { bytecode
.push_back(opCode
); }
354 /// Append an address to the bytecode.
355 void append(ByteCodeAddr field
) {
356 static_assert((sizeof(ByteCodeAddr
) / sizeof(ByteCodeField
)) == 2,
357 "unexpected ByteCode address size");
359 ByteCodeField fieldParts
[2];
360 std::memcpy(fieldParts
, &field
, sizeof(ByteCodeAddr
));
361 bytecode
.append({fieldParts
[0], fieldParts
[1]});
364 /// Append a single successor to the bytecode, the exact address will need to
365 /// be resolved later.
366 void append(Block
*successor
) {
367 // Add back a reference to the successor so that the address can be resolved
369 unresolvedSuccessorRefs
[successor
].push_back(bytecode
.size());
370 append(ByteCodeAddr(0));
373 /// Append a successor range to the bytecode, the exact address will need to
374 /// be resolved later.
375 void append(SuccessorRange successors
) {
376 for (Block
*successor
: successors
)
380 /// Append a range of values that will be read as generic PDLValues.
381 void appendPDLValueList(OperandRange values
) {
382 bytecode
.push_back(values
.size());
383 for (Value value
: values
)
384 appendPDLValue(value
);
387 /// Append a value as a PDLValue.
388 void appendPDLValue(Value value
) {
389 appendPDLValueKind(value
);
393 /// Append the PDLValue::Kind of the given value.
394 void appendPDLValueKind(Value value
) { appendPDLValueKind(value
.getType()); }
396 /// Append the PDLValue::Kind of the given type.
397 void appendPDLValueKind(Type type
) {
398 PDLValue::Kind kind
=
399 TypeSwitch
<Type
, PDLValue::Kind
>(type
)
400 .Case
<pdl::AttributeType
>(
401 [](Type
) { return PDLValue::Kind::Attribute
; })
402 .Case
<pdl::OperationType
>(
403 [](Type
) { return PDLValue::Kind::Operation
; })
404 .Case
<pdl::RangeType
>([](pdl::RangeType rangeTy
) {
405 if (isa
<pdl::TypeType
>(rangeTy
.getElementType()))
406 return PDLValue::Kind::TypeRange
;
407 return PDLValue::Kind::ValueRange
;
409 .Case
<pdl::TypeType
>([](Type
) { return PDLValue::Kind::Type
; })
410 .Case
<pdl::ValueType
>([](Type
) { return PDLValue::Kind::Value
; });
411 bytecode
.push_back(static_cast<ByteCodeField
>(kind
));
414 /// Append a value that will be stored in a memory slot and not inline within
416 template <typename T
>
417 std::enable_if_t
<llvm::is_detected
<has_pointer_traits
, T
>::value
||
418 std::is_pointer
<T
>::value
>
420 bytecode
.push_back(generator
.getMemIndex(value
));
423 /// Append a range of values.
424 template <typename T
, typename IteratorT
= llvm::detail::IterOfRange
<T
>>
425 std::enable_if_t
<!llvm::is_detected
<has_pointer_traits
, T
>::value
>
427 bytecode
.push_back(llvm::size(range
));
428 for (auto it
: range
)
432 /// Append a variadic number of fields to the bytecode.
433 template <typename FieldTy
, typename Field2Ty
, typename
... FieldTys
>
434 void append(FieldTy field
, Field2Ty field2
, FieldTys
... fields
) {
436 append(field2
, fields
...);
439 /// Appends a value as a pointer, stored inline within the bytecode.
440 template <typename T
>
441 std::enable_if_t
<llvm::is_detected
<has_pointer_traits
, T
>::value
>
442 appendInline(T value
) {
443 constexpr size_t numParts
= sizeof(const void *) / sizeof(ByteCodeField
);
444 const void *pointer
= value
.getAsOpaquePointer();
445 ByteCodeField fieldParts
[numParts
];
446 std::memcpy(fieldParts
, &pointer
, sizeof(const void *));
447 bytecode
.append(fieldParts
, fieldParts
+ numParts
);
450 /// Successor references in the bytecode that have yet to be resolved.
451 DenseMap
<Block
*, SmallVector
<unsigned, 4>> unresolvedSuccessorRefs
;
453 /// The underlying bytecode buffer.
454 SmallVectorImpl
<ByteCodeField
> &bytecode
;
456 /// The main generator producing PDL.
457 Generator
&generator
;
460 /// This class represents a live range of PDL Interpreter values, containing
461 /// information about when values are live within a match/rewrite.
462 struct ByteCodeLiveRange
{
463 using Set
= llvm::IntervalMap
<uint64_t, char, 16>;
464 using Allocator
= Set::Allocator
;
466 ByteCodeLiveRange(Allocator
&alloc
) : liveness(new Set(alloc
)) {}
468 /// Union this live range with the one provided.
469 void unionWith(const ByteCodeLiveRange
&rhs
) {
470 for (auto it
= rhs
.liveness
->begin(), e
= rhs
.liveness
->end(); it
!= e
;
472 liveness
->insert(it
.start(), it
.stop(), /*dummyValue*/ 0);
475 /// Returns true if this range overlaps with the one provided.
476 bool overlaps(const ByteCodeLiveRange
&rhs
) const {
477 return llvm::IntervalMapOverlaps
<Set
, Set
>(*liveness
, *rhs
.liveness
)
481 /// A map representing the ranges of the match/rewrite that a value is live in
484 /// We use std::unique_ptr here, because IntervalMap does not provide a
485 /// correct copy or move constructor. We can eliminate the pointer once
486 /// https://reviews.llvm.org/D113240 lands.
487 std::unique_ptr
<llvm::IntervalMap
<uint64_t, char, 16>> liveness
;
489 /// The operation range storage index for this range.
490 std::optional
<unsigned> opRangeIndex
;
492 /// The type range storage index for this range.
493 std::optional
<unsigned> typeRangeIndex
;
495 /// The value range storage index for this range.
496 std::optional
<unsigned> valueRangeIndex
;
500 void Generator::generate(ModuleOp module
) {
501 auto matcherFunc
= module
.lookupSymbol
<pdl_interp::FuncOp
>(
502 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503 ModuleOp rewriterModule
= module
.lookupSymbol
<ModuleOp
>(
504 pdl_interp::PDLInterpDialect::getRewriterModuleName());
505 assert(matcherFunc
&& rewriterModule
&& "invalid PDL Interpreter module");
507 // Allocate memory indices for the results of operations within the matcher
509 allocateMemoryIndices(matcherFunc
, rewriterModule
);
511 // Generate code for the rewriter functions.
512 ByteCodeWriter
rewriterByteCodeWriter(rewriterByteCode
, *this);
513 for (auto rewriterFunc
: rewriterModule
.getOps
<pdl_interp::FuncOp
>()) {
514 rewriterToAddr
.try_emplace(rewriterFunc
.getName(), rewriterByteCode
.size());
515 for (Operation
&op
: rewriterFunc
.getOps())
516 generate(&op
, rewriterByteCodeWriter
);
518 assert(rewriterByteCodeWriter
.unresolvedSuccessorRefs
.empty() &&
519 "unexpected branches in rewriter function");
521 // Generate code for the matcher function.
522 ByteCodeWriter
matcherByteCodeWriter(matcherByteCode
, *this);
523 generate(&matcherFunc
.getBody(), matcherByteCodeWriter
);
525 // Resolve successor references in the matcher.
526 for (auto &it
: matcherByteCodeWriter
.unresolvedSuccessorRefs
) {
527 ByteCodeAddr addr
= blockToAddr
[it
.first
];
528 for (unsigned offsetToFix
: it
.second
)
529 std::memcpy(&matcherByteCode
[offsetToFix
], &addr
, sizeof(ByteCodeAddr
));
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc
,
534 ModuleOp rewriterModule
) {
535 // Rewriters use simplistic allocation scheme that simply assigns an index to
537 for (auto rewriterFunc
: rewriterModule
.getOps
<pdl_interp::FuncOp
>()) {
538 ByteCodeField index
= 0, typeRangeIndex
= 0, valueRangeIndex
= 0;
539 auto processRewriterValue
= [&](Value val
) {
540 valueToMemIndex
.try_emplace(val
, index
++);
541 if (pdl::RangeType rangeType
= dyn_cast
<pdl::RangeType
>(val
.getType())) {
542 Type elementTy
= rangeType
.getElementType();
543 if (isa
<pdl::TypeType
>(elementTy
))
544 valueToRangeIndex
.try_emplace(val
, typeRangeIndex
++);
545 else if (isa
<pdl::ValueType
>(elementTy
))
546 valueToRangeIndex
.try_emplace(val
, valueRangeIndex
++);
550 for (BlockArgument arg
: rewriterFunc
.getArguments())
551 processRewriterValue(arg
);
552 rewriterFunc
.getBody().walk([&](Operation
*op
) {
553 for (Value result
: op
->getResults())
554 processRewriterValue(result
);
556 if (index
> maxValueMemoryIndex
)
557 maxValueMemoryIndex
= index
;
558 if (typeRangeIndex
> maxTypeRangeMemoryIndex
)
559 maxTypeRangeMemoryIndex
= typeRangeIndex
;
560 if (valueRangeIndex
> maxValueRangeMemoryIndex
)
561 maxValueRangeMemoryIndex
= valueRangeIndex
;
564 // The matcher function uses a more sophisticated numbering that tries to
565 // minimize the number of memory indices assigned. This is done by determining
566 // a live range of the values within the matcher, then the allocation is just
567 // finding the minimal number of overlapping live ranges. This is essentially
568 // a simplified form of register allocation where we don't necessarily have a
569 // limited number of registers, but we still want to minimize the number used.
570 DenseMap
<Operation
*, unsigned> opToFirstIndex
;
571 DenseMap
<Operation
*, unsigned> opToLastIndex
;
573 // A custom walk that marks the first and the last index of each operation.
574 // The entry marks the beginning of the liveness range for this operation,
575 // followed by nested operations, followed by the end of the liveness range.
577 llvm::unique_function
<void(Operation
*)> walk
= [&](Operation
*op
) {
578 opToFirstIndex
.try_emplace(op
, index
++);
579 for (Region
®ion
: op
->getRegions())
580 for (Block
&block
: region
.getBlocks())
581 for (Operation
&nested
: block
)
583 opToLastIndex
.try_emplace(op
, index
++);
587 // Liveness info for each of the defs within the matcher.
588 ByteCodeLiveRange::Allocator allocator
;
589 DenseMap
<Value
, ByteCodeLiveRange
> valueDefRanges
;
591 // Assign the root operation being matched to slot 0.
592 BlockArgument rootOpArg
= matcherFunc
.getArgument(0);
593 valueToMemIndex
[rootOpArg
] = 0;
595 // Walk each of the blocks, computing the def interval that the value is used.
596 Liveness
matcherLiveness(matcherFunc
);
597 matcherFunc
->walk([&](Block
*block
) {
598 const LivenessBlockInfo
*info
= matcherLiveness
.getLiveness(block
);
599 assert(info
&& "expected liveness info for block");
600 auto processValue
= [&](Value value
, Operation
*firstUseOrDef
) {
601 // We don't need to process the root op argument, this value is always
602 // assigned to the first memory slot.
603 if (value
== rootOpArg
)
606 // Set indices for the range of this block that the value is used.
607 auto defRangeIt
= valueDefRanges
.try_emplace(value
, allocator
).first
;
608 defRangeIt
->second
.liveness
->insert(
609 opToFirstIndex
[firstUseOrDef
],
610 opToLastIndex
[info
->getEndOperation(value
, firstUseOrDef
)],
613 // Check to see if this value is a range type.
614 if (auto rangeTy
= dyn_cast
<pdl::RangeType
>(value
.getType())) {
615 Type eleType
= rangeTy
.getElementType();
616 if (isa
<pdl::OperationType
>(eleType
))
617 defRangeIt
->second
.opRangeIndex
= 0;
618 else if (isa
<pdl::TypeType
>(eleType
))
619 defRangeIt
->second
.typeRangeIndex
= 0;
620 else if (isa
<pdl::ValueType
>(eleType
))
621 defRangeIt
->second
.valueRangeIndex
= 0;
625 // Process the live-ins of this block.
626 for (Value liveIn
: info
->in()) {
627 // Only process the value if it has been defined in the current region.
628 // Other values that span across pdl_interp.foreach will be added higher
629 // up. This ensures that the we keep them alive for the entire duration
631 if (liveIn
.getParentRegion() == block
->getParent())
632 processValue(liveIn
, &block
->front());
635 // Process the block arguments for the entry block (those are not live-in).
636 if (block
->isEntryBlock()) {
637 for (Value argument
: block
->getArguments())
638 processValue(argument
, &block
->front());
641 // Process any new defs within this block.
642 for (Operation
&op
: *block
)
643 for (Value result
: op
.getResults())
644 processValue(result
, &op
);
647 // Greedily allocate memory slots using the computed def live ranges.
648 std::vector
<ByteCodeLiveRange
> allocatedIndices
;
650 // The number of memory indices currently allocated (and its next value).
651 // Recall that the root gets allocated memory index 0.
652 ByteCodeField numIndices
= 1;
654 // The number of memory ranges of various types (and their next values).
655 ByteCodeField numOpRanges
= 0, numTypeRanges
= 0, numValueRanges
= 0;
657 for (auto &defIt
: valueDefRanges
) {
658 ByteCodeField
&memIndex
= valueToMemIndex
[defIt
.first
];
659 ByteCodeLiveRange
&defRange
= defIt
.second
;
661 // Try to allocate to an existing index.
662 for (const auto &existingIndexIt
: llvm::enumerate(allocatedIndices
)) {
663 ByteCodeLiveRange
&existingRange
= existingIndexIt
.value();
664 if (!defRange
.overlaps(existingRange
)) {
665 existingRange
.unionWith(defRange
);
666 memIndex
= existingIndexIt
.index() + 1;
668 if (defRange
.opRangeIndex
) {
669 if (!existingRange
.opRangeIndex
)
670 existingRange
.opRangeIndex
= numOpRanges
++;
671 valueToRangeIndex
[defIt
.first
] = *existingRange
.opRangeIndex
;
672 } else if (defRange
.typeRangeIndex
) {
673 if (!existingRange
.typeRangeIndex
)
674 existingRange
.typeRangeIndex
= numTypeRanges
++;
675 valueToRangeIndex
[defIt
.first
] = *existingRange
.typeRangeIndex
;
676 } else if (defRange
.valueRangeIndex
) {
677 if (!existingRange
.valueRangeIndex
)
678 existingRange
.valueRangeIndex
= numValueRanges
++;
679 valueToRangeIndex
[defIt
.first
] = *existingRange
.valueRangeIndex
;
685 // If no existing index could be used, add a new one.
687 allocatedIndices
.emplace_back(allocator
);
688 ByteCodeLiveRange
&newRange
= allocatedIndices
.back();
689 newRange
.unionWith(defRange
);
691 // Allocate an index for op/type/value ranges.
692 if (defRange
.opRangeIndex
) {
693 newRange
.opRangeIndex
= numOpRanges
;
694 valueToRangeIndex
[defIt
.first
] = numOpRanges
++;
695 } else if (defRange
.typeRangeIndex
) {
696 newRange
.typeRangeIndex
= numTypeRanges
;
697 valueToRangeIndex
[defIt
.first
] = numTypeRanges
++;
698 } else if (defRange
.valueRangeIndex
) {
699 newRange
.valueRangeIndex
= numValueRanges
;
700 valueToRangeIndex
[defIt
.first
] = numValueRanges
++;
703 memIndex
= allocatedIndices
.size();
708 // Print the index usage and ensure that we did not run out of index space.
710 llvm::dbgs() << "Allocated " << allocatedIndices
.size() << " indices "
711 << "(down from initial " << valueDefRanges
.size() << ").\n";
713 assert(allocatedIndices
.size() <= std::numeric_limits
<ByteCodeField
>::max() &&
714 "Ran out of memory for allocated indices");
716 // Update the max number of indices.
717 if (numIndices
> maxValueMemoryIndex
)
718 maxValueMemoryIndex
= numIndices
;
719 if (numOpRanges
> maxOpRangeMemoryIndex
)
720 maxOpRangeMemoryIndex
= numOpRanges
;
721 if (numTypeRanges
> maxTypeRangeMemoryIndex
)
722 maxTypeRangeMemoryIndex
= numTypeRanges
;
723 if (numValueRanges
> maxValueRangeMemoryIndex
)
724 maxValueRangeMemoryIndex
= numValueRanges
;
727 void Generator::generate(Region
*region
, ByteCodeWriter
&writer
) {
728 llvm::ReversePostOrderTraversal
<Region
*> rpot(region
);
729 for (Block
*block
: rpot
) {
730 // Keep track of where this block begins within the matcher function.
731 blockToAddr
.try_emplace(block
, matcherByteCode
.size());
732 for (Operation
&op
: *block
)
733 generate(&op
, writer
);
737 void Generator::generate(Operation
*op
, ByteCodeWriter
&writer
) {
739 // The following list must contain all the operations that do not
740 // produce any bytecode.
741 if (!isa
<pdl_interp::CreateAttributeOp
, pdl_interp::CreateTypeOp
>(op
))
742 writer
.appendInline(op
->getLoc());
744 TypeSwitch
<Operation
*>(op
)
745 .Case
<pdl_interp::ApplyConstraintOp
, pdl_interp::ApplyRewriteOp
,
746 pdl_interp::AreEqualOp
, pdl_interp::BranchOp
,
747 pdl_interp::CheckAttributeOp
, pdl_interp::CheckOperandCountOp
,
748 pdl_interp::CheckOperationNameOp
, pdl_interp::CheckResultCountOp
,
749 pdl_interp::CheckTypeOp
, pdl_interp::CheckTypesOp
,
750 pdl_interp::ContinueOp
, pdl_interp::CreateAttributeOp
,
751 pdl_interp::CreateOperationOp
, pdl_interp::CreateRangeOp
,
752 pdl_interp::CreateTypeOp
, pdl_interp::CreateTypesOp
,
753 pdl_interp::EraseOp
, pdl_interp::ExtractOp
, pdl_interp::FinalizeOp
,
754 pdl_interp::ForEachOp
, pdl_interp::GetAttributeOp
,
755 pdl_interp::GetAttributeTypeOp
, pdl_interp::GetDefiningOpOp
,
756 pdl_interp::GetOperandOp
, pdl_interp::GetOperandsOp
,
757 pdl_interp::GetResultOp
, pdl_interp::GetResultsOp
,
758 pdl_interp::GetUsersOp
, pdl_interp::GetValueTypeOp
,
759 pdl_interp::IsNotNullOp
, pdl_interp::RecordMatchOp
,
760 pdl_interp::ReplaceOp
, pdl_interp::SwitchAttributeOp
,
761 pdl_interp::SwitchTypeOp
, pdl_interp::SwitchTypesOp
,
762 pdl_interp::SwitchOperandCountOp
, pdl_interp::SwitchOperationNameOp
,
763 pdl_interp::SwitchResultCountOp
>(
764 [&](auto interpOp
) { this->generate(interpOp
, writer
); })
765 .Default([](Operation
*) {
766 llvm_unreachable("unknown `pdl_interp` operation");
770 void Generator::generate(pdl_interp::ApplyConstraintOp op
,
771 ByteCodeWriter
&writer
) {
772 // Constraints that should return a value have to be registered as rewrites.
773 // If a constraint and a rewrite of similar name are registered the
774 // constraint takes precedence
775 writer
.append(OpCode::ApplyConstraint
, constraintToMemIndex
[op
.getName()]);
776 writer
.appendPDLValueList(op
.getArgs());
777 writer
.append(ByteCodeField(op
.getIsNegated()));
778 ResultRange results
= op
.getResults();
779 writer
.append(ByteCodeField(results
.size()));
780 for (Value result
: results
) {
781 // We record the expected kind of the result, so that we can provide extra
782 // verification of the native rewrite function and handle the failure case
783 // of constraints accordingly.
784 writer
.appendPDLValueKind(result
);
786 // Range results also need to append the range storage index.
787 if (isa
<pdl::RangeType
>(result
.getType()))
788 writer
.append(getRangeStorageIndex(result
));
789 writer
.append(result
);
791 writer
.append(op
.getSuccessors());
793 void Generator::generate(pdl_interp::ApplyRewriteOp op
,
794 ByteCodeWriter
&writer
) {
795 assert(externalRewriterToMemIndex
.count(op
.getName()) &&
796 "expected index for rewrite function");
797 writer
.append(OpCode::ApplyRewrite
, externalRewriterToMemIndex
[op
.getName()]);
798 writer
.appendPDLValueList(op
.getArgs());
800 ResultRange results
= op
.getResults();
801 writer
.append(ByteCodeField(results
.size()));
802 for (Value result
: results
) {
803 // We record the expected kind of the result, so that we
804 // can provide extra verification of the native rewrite function.
805 writer
.appendPDLValueKind(result
);
807 // Range results also need to append the range storage index.
808 if (isa
<pdl::RangeType
>(result
.getType()))
809 writer
.append(getRangeStorageIndex(result
));
810 writer
.append(result
);
813 void Generator::generate(pdl_interp::AreEqualOp op
, ByteCodeWriter
&writer
) {
814 Value lhs
= op
.getLhs();
815 if (isa
<pdl::RangeType
>(lhs
.getType())) {
816 writer
.append(OpCode::AreRangesEqual
);
817 writer
.appendPDLValueKind(lhs
);
818 writer
.append(op
.getLhs(), op
.getRhs(), op
.getSuccessors());
822 writer
.append(OpCode::AreEqual
, lhs
, op
.getRhs(), op
.getSuccessors());
824 void Generator::generate(pdl_interp::BranchOp op
, ByteCodeWriter
&writer
) {
825 writer
.append(OpCode::Branch
, SuccessorRange(op
.getOperation()));
827 void Generator::generate(pdl_interp::CheckAttributeOp op
,
828 ByteCodeWriter
&writer
) {
829 writer
.append(OpCode::AreEqual
, op
.getAttribute(), op
.getConstantValue(),
832 void Generator::generate(pdl_interp::CheckOperandCountOp op
,
833 ByteCodeWriter
&writer
) {
834 writer
.append(OpCode::CheckOperandCount
, op
.getInputOp(), op
.getCount(),
835 static_cast<ByteCodeField
>(op
.getCompareAtLeast()),
838 void Generator::generate(pdl_interp::CheckOperationNameOp op
,
839 ByteCodeWriter
&writer
) {
840 writer
.append(OpCode::CheckOperationName
, op
.getInputOp(),
841 OperationName(op
.getName(), ctx
), op
.getSuccessors());
843 void Generator::generate(pdl_interp::CheckResultCountOp op
,
844 ByteCodeWriter
&writer
) {
845 writer
.append(OpCode::CheckResultCount
, op
.getInputOp(), op
.getCount(),
846 static_cast<ByteCodeField
>(op
.getCompareAtLeast()),
849 void Generator::generate(pdl_interp::CheckTypeOp op
, ByteCodeWriter
&writer
) {
850 writer
.append(OpCode::AreEqual
, op
.getValue(), op
.getType(),
853 void Generator::generate(pdl_interp::CheckTypesOp op
, ByteCodeWriter
&writer
) {
854 writer
.append(OpCode::CheckTypes
, op
.getValue(), op
.getTypes(),
857 void Generator::generate(pdl_interp::ContinueOp op
, ByteCodeWriter
&writer
) {
858 assert(curLoopLevel
> 0 && "encountered pdl_interp.continue at top level");
859 writer
.append(OpCode::Continue
, ByteCodeField(curLoopLevel
- 1));
861 void Generator::generate(pdl_interp::CreateAttributeOp op
,
862 ByteCodeWriter
&writer
) {
863 // Simply repoint the memory index of the result to the constant.
864 getMemIndex(op
.getAttribute()) = getMemIndex(op
.getValue());
866 void Generator::generate(pdl_interp::CreateOperationOp op
,
867 ByteCodeWriter
&writer
) {
868 writer
.append(OpCode::CreateOperation
, op
.getResultOp(),
869 OperationName(op
.getName(), ctx
));
870 writer
.appendPDLValueList(op
.getInputOperands());
872 // Add the attributes.
873 OperandRange attributes
= op
.getInputAttributes();
874 writer
.append(static_cast<ByteCodeField
>(attributes
.size()));
875 for (auto it
: llvm::zip(op
.getInputAttributeNames(), attributes
))
876 writer
.append(std::get
<0>(it
), std::get
<1>(it
));
878 // Add the result types. If the operation has inferred results, we use a
879 // marker "size" value. Otherwise, we add the list of explicit result types.
880 if (op
.getInferredResultTypes())
881 writer
.append(kInferTypesMarker
);
883 writer
.appendPDLValueList(op
.getInputResultTypes());
885 void Generator::generate(pdl_interp::CreateRangeOp op
, ByteCodeWriter
&writer
) {
886 // Append the correct opcode for the range type.
887 TypeSwitch
<Type
>(op
.getType().getElementType())
889 [&](pdl::TypeType
) { writer
.append(OpCode::CreateDynamicTypeRange
); })
890 .Case([&](pdl::ValueType
) {
891 writer
.append(OpCode::CreateDynamicValueRange
);
894 writer
.append(op
.getResult(), getRangeStorageIndex(op
.getResult()));
895 writer
.appendPDLValueList(op
->getOperands());
897 void Generator::generate(pdl_interp::CreateTypeOp op
, ByteCodeWriter
&writer
) {
898 // Simply repoint the memory index of the result to the constant.
899 getMemIndex(op
.getResult()) = getMemIndex(op
.getValue());
901 void Generator::generate(pdl_interp::CreateTypesOp op
, ByteCodeWriter
&writer
) {
902 writer
.append(OpCode::CreateConstantTypeRange
, op
.getResult(),
903 getRangeStorageIndex(op
.getResult()), op
.getValue());
905 void Generator::generate(pdl_interp::EraseOp op
, ByteCodeWriter
&writer
) {
906 writer
.append(OpCode::EraseOp
, op
.getInputOp());
908 void Generator::generate(pdl_interp::ExtractOp op
, ByteCodeWriter
&writer
) {
910 TypeSwitch
<Type
, OpCode
>(op
.getResult().getType())
911 .Case([](pdl::OperationType
) { return OpCode::ExtractOp
; })
912 .Case([](pdl::ValueType
) { return OpCode::ExtractValue
; })
913 .Case([](pdl::TypeType
) { return OpCode::ExtractType
; })
914 .Default([](Type
) -> OpCode
{
915 llvm_unreachable("unsupported element type");
917 writer
.append(opCode
, op
.getRange(), op
.getIndex(), op
.getResult());
919 void Generator::generate(pdl_interp::FinalizeOp op
, ByteCodeWriter
&writer
) {
920 writer
.append(OpCode::Finalize
);
922 void Generator::generate(pdl_interp::ForEachOp op
, ByteCodeWriter
&writer
) {
923 BlockArgument arg
= op
.getLoopVariable();
924 writer
.append(OpCode::ForEach
, getRangeStorageIndex(op
.getValues()), arg
);
925 writer
.appendPDLValueKind(arg
.getType());
926 writer
.append(curLoopLevel
, op
.getSuccessor());
928 if (curLoopLevel
> maxLoopLevel
)
929 maxLoopLevel
= curLoopLevel
;
930 generate(&op
.getRegion(), writer
);
933 void Generator::generate(pdl_interp::GetAttributeOp op
,
934 ByteCodeWriter
&writer
) {
935 writer
.append(OpCode::GetAttribute
, op
.getAttribute(), op
.getInputOp(),
938 void Generator::generate(pdl_interp::GetAttributeTypeOp op
,
939 ByteCodeWriter
&writer
) {
940 writer
.append(OpCode::GetAttributeType
, op
.getResult(), op
.getValue());
942 void Generator::generate(pdl_interp::GetDefiningOpOp op
,
943 ByteCodeWriter
&writer
) {
944 writer
.append(OpCode::GetDefiningOp
, op
.getInputOp());
945 writer
.appendPDLValue(op
.getValue());
947 void Generator::generate(pdl_interp::GetOperandOp op
, ByteCodeWriter
&writer
) {
948 uint32_t index
= op
.getIndex();
950 writer
.append(static_cast<OpCode
>(OpCode::GetOperand0
+ index
));
952 writer
.append(OpCode::GetOperandN
, index
);
953 writer
.append(op
.getInputOp(), op
.getValue());
955 void Generator::generate(pdl_interp::GetOperandsOp op
, ByteCodeWriter
&writer
) {
956 Value result
= op
.getValue();
957 std::optional
<uint32_t> index
= op
.getIndex();
958 writer
.append(OpCode::GetOperands
,
959 index
.value_or(std::numeric_limits
<uint32_t>::max()),
961 if (isa
<pdl::RangeType
>(result
.getType()))
962 writer
.append(getRangeStorageIndex(result
));
964 writer
.append(std::numeric_limits
<ByteCodeField
>::max());
965 writer
.append(result
);
967 void Generator::generate(pdl_interp::GetResultOp op
, ByteCodeWriter
&writer
) {
968 uint32_t index
= op
.getIndex();
970 writer
.append(static_cast<OpCode
>(OpCode::GetResult0
+ index
));
972 writer
.append(OpCode::GetResultN
, index
);
973 writer
.append(op
.getInputOp(), op
.getValue());
975 void Generator::generate(pdl_interp::GetResultsOp op
, ByteCodeWriter
&writer
) {
976 Value result
= op
.getValue();
977 std::optional
<uint32_t> index
= op
.getIndex();
978 writer
.append(OpCode::GetResults
,
979 index
.value_or(std::numeric_limits
<uint32_t>::max()),
981 if (isa
<pdl::RangeType
>(result
.getType()))
982 writer
.append(getRangeStorageIndex(result
));
984 writer
.append(std::numeric_limits
<ByteCodeField
>::max());
985 writer
.append(result
);
987 void Generator::generate(pdl_interp::GetUsersOp op
, ByteCodeWriter
&writer
) {
988 Value operations
= op
.getOperations();
989 ByteCodeField rangeIndex
= getRangeStorageIndex(operations
);
990 writer
.append(OpCode::GetUsers
, operations
, rangeIndex
);
991 writer
.appendPDLValue(op
.getValue());
993 void Generator::generate(pdl_interp::GetValueTypeOp op
,
994 ByteCodeWriter
&writer
) {
995 if (isa
<pdl::RangeType
>(op
.getType())) {
996 Value result
= op
.getResult();
997 writer
.append(OpCode::GetValueRangeTypes
, result
,
998 getRangeStorageIndex(result
), op
.getValue());
1000 writer
.append(OpCode::GetValueType
, op
.getResult(), op
.getValue());
1003 void Generator::generate(pdl_interp::IsNotNullOp op
, ByteCodeWriter
&writer
) {
1004 writer
.append(OpCode::IsNotNull
, op
.getValue(), op
.getSuccessors());
1006 void Generator::generate(pdl_interp::RecordMatchOp op
, ByteCodeWriter
&writer
) {
1007 ByteCodeField patternIndex
= patterns
.size();
1008 patterns
.emplace_back(PDLByteCodePattern::create(
1009 op
, configMap
.lookup(op
),
1010 rewriterToAddr
[op
.getRewriter().getLeafReference().getValue()]));
1011 writer
.append(OpCode::RecordMatch
, patternIndex
,
1012 SuccessorRange(op
.getOperation()), op
.getMatchedOps());
1013 writer
.appendPDLValueList(op
.getInputs());
1015 void Generator::generate(pdl_interp::ReplaceOp op
, ByteCodeWriter
&writer
) {
1016 writer
.append(OpCode::ReplaceOp
, op
.getInputOp());
1017 writer
.appendPDLValueList(op
.getReplValues());
1019 void Generator::generate(pdl_interp::SwitchAttributeOp op
,
1020 ByteCodeWriter
&writer
) {
1021 writer
.append(OpCode::SwitchAttribute
, op
.getAttribute(),
1022 op
.getCaseValuesAttr(), op
.getSuccessors());
1024 void Generator::generate(pdl_interp::SwitchOperandCountOp op
,
1025 ByteCodeWriter
&writer
) {
1026 writer
.append(OpCode::SwitchOperandCount
, op
.getInputOp(),
1027 op
.getCaseValuesAttr(), op
.getSuccessors());
1029 void Generator::generate(pdl_interp::SwitchOperationNameOp op
,
1030 ByteCodeWriter
&writer
) {
1031 auto cases
= llvm::map_range(op
.getCaseValuesAttr(), [&](Attribute attr
) {
1032 return OperationName(cast
<StringAttr
>(attr
).getValue(), ctx
);
1034 writer
.append(OpCode::SwitchOperationName
, op
.getInputOp(), cases
,
1035 op
.getSuccessors());
1037 void Generator::generate(pdl_interp::SwitchResultCountOp op
,
1038 ByteCodeWriter
&writer
) {
1039 writer
.append(OpCode::SwitchResultCount
, op
.getInputOp(),
1040 op
.getCaseValuesAttr(), op
.getSuccessors());
1042 void Generator::generate(pdl_interp::SwitchTypeOp op
, ByteCodeWriter
&writer
) {
1043 writer
.append(OpCode::SwitchType
, op
.getValue(), op
.getCaseValuesAttr(),
1044 op
.getSuccessors());
1046 void Generator::generate(pdl_interp::SwitchTypesOp op
, ByteCodeWriter
&writer
) {
1047 writer
.append(OpCode::SwitchTypes
, op
.getValue(), op
.getCaseValuesAttr(),
1048 op
.getSuccessors());
1051 //===----------------------------------------------------------------------===//
1053 //===----------------------------------------------------------------------===//
1055 PDLByteCode::PDLByteCode(
1056 ModuleOp module
, SmallVector
<std::unique_ptr
<PDLPatternConfigSet
>> configs
,
1057 const DenseMap
<Operation
*, PDLPatternConfigSet
*> &configMap
,
1058 llvm::StringMap
<PDLConstraintFunction
> constraintFns
,
1059 llvm::StringMap
<PDLRewriteFunction
> rewriteFns
)
1060 : configs(std::move(configs
)) {
1061 Generator
generator(module
.getContext(), uniquedData
, matcherByteCode
,
1062 rewriterByteCode
, patterns
, maxValueMemoryIndex
,
1063 maxOpRangeCount
, maxTypeRangeCount
, maxValueRangeCount
,
1064 maxLoopLevel
, constraintFns
, rewriteFns
, configMap
);
1065 generator
.generate(module
);
1067 // Initialize the external functions.
1068 for (auto &it
: constraintFns
)
1069 constraintFunctions
.push_back(std::move(it
.second
));
1070 for (auto &it
: rewriteFns
)
1071 rewriteFunctions
.push_back(std::move(it
.second
));
1074 /// Initialize the given state such that it can be used to execute the current
1076 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState
&state
) const {
1077 state
.memory
.resize(maxValueMemoryIndex
, nullptr);
1078 state
.opRangeMemory
.resize(maxOpRangeCount
);
1079 state
.typeRangeMemory
.resize(maxTypeRangeCount
, TypeRange());
1080 state
.valueRangeMemory
.resize(maxValueRangeCount
, ValueRange());
1081 state
.loopIndex
.resize(maxLoopLevel
, 0);
1082 state
.currentPatternBenefits
.reserve(patterns
.size());
1083 for (const PDLByteCodePattern
&pattern
: patterns
)
1084 state
.currentPatternBenefits
.push_back(pattern
.getBenefit());
1087 //===----------------------------------------------------------------------===//
1088 // ByteCode Execution
1091 /// This class is an instantiation of the PDLResultList that provides access to
1092 /// the returned results. This API is not on `PDLResultList` to avoid
1093 /// overexposing access to information specific solely to the ByteCode.
1094 class ByteCodeRewriteResultList
: public PDLResultList
{
1096 ByteCodeRewriteResultList(unsigned maxNumResults
)
1097 : PDLResultList(maxNumResults
) {}
1099 /// Return the list of PDL results.
1100 MutableArrayRef
<PDLValue
> getResults() { return results
; }
1102 /// Return the type ranges allocated by this list.
1103 MutableArrayRef
<llvm::OwningArrayRef
<Type
>> getAllocatedTypeRanges() {
1104 return allocatedTypeRanges
;
1107 /// Return the value ranges allocated by this list.
1108 MutableArrayRef
<llvm::OwningArrayRef
<Value
>> getAllocatedValueRanges() {
1109 return allocatedValueRanges
;
1113 /// This class provides support for executing a bytecode stream.
1114 class ByteCodeExecutor
{
1117 const ByteCodeField
*curCodeIt
, MutableArrayRef
<const void *> memory
,
1118 MutableArrayRef
<llvm::OwningArrayRef
<Operation
*>> opRangeMemory
,
1119 MutableArrayRef
<TypeRange
> typeRangeMemory
,
1120 std::vector
<llvm::OwningArrayRef
<Type
>> &allocatedTypeRangeMemory
,
1121 MutableArrayRef
<ValueRange
> valueRangeMemory
,
1122 std::vector
<llvm::OwningArrayRef
<Value
>> &allocatedValueRangeMemory
,
1123 MutableArrayRef
<unsigned> loopIndex
, ArrayRef
<const void *> uniquedMemory
,
1124 ArrayRef
<ByteCodeField
> code
,
1125 ArrayRef
<PatternBenefit
> currentPatternBenefits
,
1126 ArrayRef
<PDLByteCodePattern
> patterns
,
1127 ArrayRef
<PDLConstraintFunction
> constraintFunctions
,
1128 ArrayRef
<PDLRewriteFunction
> rewriteFunctions
)
1129 : curCodeIt(curCodeIt
), memory(memory
), opRangeMemory(opRangeMemory
),
1130 typeRangeMemory(typeRangeMemory
),
1131 allocatedTypeRangeMemory(allocatedTypeRangeMemory
),
1132 valueRangeMemory(valueRangeMemory
),
1133 allocatedValueRangeMemory(allocatedValueRangeMemory
),
1134 loopIndex(loopIndex
), uniquedMemory(uniquedMemory
), code(code
),
1135 currentPatternBenefits(currentPatternBenefits
), patterns(patterns
),
1136 constraintFunctions(constraintFunctions
),
1137 rewriteFunctions(rewriteFunctions
) {}
1139 /// Start executing the code at the current bytecode index. `matches` is an
1140 /// optional field provided when this function is executed in a matching
1143 execute(PatternRewriter
&rewriter
,
1144 SmallVectorImpl
<PDLByteCode::MatchResult
> *matches
= nullptr,
1145 std::optional
<Location
> mainRewriteLoc
= {});
1148 /// Internal implementation of executing each of the bytecode commands.
1149 void executeApplyConstraint(PatternRewriter
&rewriter
);
1150 LogicalResult
executeApplyRewrite(PatternRewriter
&rewriter
);
1151 void executeAreEqual();
1152 void executeAreRangesEqual();
1153 void executeBranch();
1154 void executeCheckOperandCount();
1155 void executeCheckOperationName();
1156 void executeCheckResultCount();
1157 void executeCheckTypes();
1158 void executeContinue();
1159 void executeCreateConstantTypeRange();
1160 void executeCreateOperation(PatternRewriter
&rewriter
,
1161 Location mainRewriteLoc
);
1162 template <typename T
>
1163 void executeDynamicCreateRange(StringRef type
);
1164 void executeEraseOp(PatternRewriter
&rewriter
);
1165 template <typename T
, typename Range
, PDLValue::Kind kind
>
1166 void executeExtract();
1167 void executeFinalize();
1168 void executeForEach();
1169 void executeGetAttribute();
1170 void executeGetAttributeType();
1171 void executeGetDefiningOp();
1172 void executeGetOperand(unsigned index
);
1173 void executeGetOperands();
1174 void executeGetResult(unsigned index
);
1175 void executeGetResults();
1176 void executeGetUsers();
1177 void executeGetValueType();
1178 void executeGetValueRangeTypes();
1179 void executeIsNotNull();
1180 void executeRecordMatch(PatternRewriter
&rewriter
,
1181 SmallVectorImpl
<PDLByteCode::MatchResult
> &matches
);
1182 void executeReplaceOp(PatternRewriter
&rewriter
);
1183 void executeSwitchAttribute();
1184 void executeSwitchOperandCount();
1185 void executeSwitchOperationName();
1186 void executeSwitchResultCount();
1187 void executeSwitchType();
1188 void executeSwitchTypes();
1189 void processNativeFunResults(ByteCodeRewriteResultList
&results
,
1190 unsigned numResults
,
1191 LogicalResult
&rewriteResult
);
1193 /// Pushes a code iterator to the stack.
1194 void pushCodeIt(const ByteCodeField
*it
) { resumeCodeIt
.push_back(it
); }
1196 /// Pops a code iterator from the stack, returning true on success.
1198 assert(!resumeCodeIt
.empty() && "attempt to pop code off empty stack");
1199 curCodeIt
= resumeCodeIt
.back();
1200 resumeCodeIt
.pop_back();
1203 /// Return the bytecode iterator at the start of the current op code.
1204 const ByteCodeField
*getPrevCodeIt() const {
1206 // Account for the op code and the Location stored inline.
1207 return curCodeIt
- 1 - sizeof(const void *) / sizeof(ByteCodeField
);
1210 // Account for the op code only.
1211 return curCodeIt
- 1;
1214 /// Read a value from the bytecode buffer, optionally skipping a certain
1215 /// number of prefix values. These methods always update the buffer to point
1216 /// to the next field after the read data.
1217 template <typename T
= ByteCodeField
>
1218 T
read(size_t skipN
= 0) {
1220 return readImpl
<T
>();
1222 ByteCodeField
read(size_t skipN
= 0) { return read
<ByteCodeField
>(skipN
); }
1224 /// Read a list of values from the bytecode buffer.
1225 template <typename ValueT
, typename T
>
1226 void readList(SmallVectorImpl
<T
> &list
) {
1228 for (unsigned i
= 0, e
= read(); i
!= e
; ++i
)
1229 list
.push_back(read
<ValueT
>());
1232 /// Read a list of values from the bytecode buffer. The values may be encoded
1233 /// either as a single element or a range of elements.
1234 void readList(SmallVectorImpl
<Type
> &list
) {
1235 for (unsigned i
= 0, e
= read(); i
!= e
; ++i
) {
1236 if (read
<PDLValue::Kind
>() == PDLValue::Kind::Type
) {
1237 list
.push_back(read
<Type
>());
1239 TypeRange
*values
= read
<TypeRange
*>();
1240 list
.append(values
->begin(), values
->end());
1244 void readList(SmallVectorImpl
<Value
> &list
) {
1245 for (unsigned i
= 0, e
= read(); i
!= e
; ++i
) {
1246 if (read
<PDLValue::Kind
>() == PDLValue::Kind::Value
) {
1247 list
.push_back(read
<Value
>());
1249 ValueRange
*values
= read
<ValueRange
*>();
1250 list
.append(values
->begin(), values
->end());
1255 /// Read a value stored inline as a pointer.
1256 template <typename T
>
1257 std::enable_if_t
<llvm::is_detected
<has_pointer_traits
, T
>::value
, T
>
1259 const void *pointer
;
1260 std::memcpy(&pointer
, curCodeIt
, sizeof(const void *));
1261 curCodeIt
+= sizeof(const void *) / sizeof(ByteCodeField
);
1262 return T::getFromOpaquePointer(pointer
);
1265 void skip(size_t skipN
) { curCodeIt
+= skipN
; }
1267 /// Jump to a specific successor based on a predicate value.
1268 void selectJump(bool isTrue
) { selectJump(size_t(isTrue
? 0 : 1)); }
1269 /// Jump to a specific successor based on a destination index.
1270 void selectJump(size_t destIndex
) {
1271 curCodeIt
= &code
[read
<ByteCodeAddr
>(destIndex
* 2)];
1274 /// Handle a switch operation with the provided value and cases.
1275 template <typename T
, typename RangeT
, typename Comparator
= std::equal_to
<T
>>
1276 void handleSwitch(const T
&value
, RangeT
&&cases
, Comparator cmp
= {}) {
1278 llvm::dbgs() << " * Value: " << value
<< "\n"
1280 llvm::interleaveComma(cases
, llvm::dbgs());
1281 llvm::dbgs() << "\n";
1284 // Check to see if the attribute value is within the case list. Jump to
1285 // the correct successor index based on the result.
1286 for (auto it
= cases
.begin(), e
= cases
.end(); it
!= e
; ++it
)
1287 if (cmp(*it
, value
))
1288 return selectJump(size_t((it
- cases
.begin()) + 1));
1289 selectJump(size_t(0));
1292 /// Store a pointer to memory.
1293 void storeToMemory(unsigned index
, const void *value
) {
1294 memory
[index
] = value
;
1297 /// Store a value to memory as an opaque pointer.
1298 template <typename T
>
1299 std::enable_if_t
<llvm::is_detected
<has_pointer_traits
, T
>::value
>
1300 storeToMemory(unsigned index
, T value
) {
1301 memory
[index
] = value
.getAsOpaquePointer();
1304 /// Internal implementation of reading various data types from the bytecode
1306 template <typename T
>
1307 const void *readFromMemory() {
1308 size_t index
= *curCodeIt
++;
1310 // If this type is an SSA value, it can only be stored in non-const memory.
1311 if (llvm::is_one_of
<T
, Operation
*, TypeRange
*, ValueRange
*,
1313 index
< memory
.size())
1314 return memory
[index
];
1316 // Otherwise, if this index is not inbounds it is uniqued.
1317 return uniquedMemory
[index
- memory
.size()];
1319 template <typename T
>
1320 std::enable_if_t
<std::is_pointer
<T
>::value
, T
> readImpl() {
1321 return reinterpret_cast<T
>(const_cast<void *>(readFromMemory
<T
>()));
1323 template <typename T
>
1324 std::enable_if_t
<std::is_class
<T
>::value
&& !std::is_same
<PDLValue
, T
>::value
,
1327 return T(T::getFromOpaquePointer(readFromMemory
<T
>()));
1329 template <typename T
>
1330 std::enable_if_t
<std::is_same
<PDLValue
, T
>::value
, T
> readImpl() {
1331 switch (read
<PDLValue::Kind
>()) {
1332 case PDLValue::Kind::Attribute
:
1333 return read
<Attribute
>();
1334 case PDLValue::Kind::Operation
:
1335 return read
<Operation
*>();
1336 case PDLValue::Kind::Type
:
1337 return read
<Type
>();
1338 case PDLValue::Kind::Value
:
1339 return read
<Value
>();
1340 case PDLValue::Kind::TypeRange
:
1341 return read
<TypeRange
*>();
1342 case PDLValue::Kind::ValueRange
:
1343 return read
<ValueRange
*>();
1345 llvm_unreachable("unhandled PDLValue::Kind");
1347 template <typename T
>
1348 std::enable_if_t
<std::is_same
<T
, ByteCodeAddr
>::value
, T
> readImpl() {
1349 static_assert((sizeof(ByteCodeAddr
) / sizeof(ByteCodeField
)) == 2,
1350 "unexpected ByteCode address size");
1351 ByteCodeAddr result
;
1352 std::memcpy(&result
, curCodeIt
, sizeof(ByteCodeAddr
));
1356 template <typename T
>
1357 std::enable_if_t
<std::is_same
<T
, ByteCodeField
>::value
, T
> readImpl() {
1358 return *curCodeIt
++;
1360 template <typename T
>
1361 std::enable_if_t
<std::is_same
<T
, PDLValue::Kind
>::value
, T
> readImpl() {
1362 return static_cast<PDLValue::Kind
>(readImpl
<ByteCodeField
>());
1365 /// Assign the given range to the given memory index. This allocates a new
1366 /// range object if necessary.
1367 template <typename RangeT
, typename T
= llvm::detail::ValueOfRange
<RangeT
>>
1368 void assignRangeToMemory(RangeT
&&range
, unsigned memIndex
,
1369 unsigned rangeIndex
) {
1370 // Utility functor used to type-erase the assignment.
1371 auto assignRange
= [&](auto &allocatedRangeMemory
, auto &rangeMemory
) {
1372 // If the input range is empty, we don't need to allocate anything.
1373 if (range
.empty()) {
1374 rangeMemory
[rangeIndex
] = {};
1376 // Allocate a buffer for this type range.
1377 llvm::OwningArrayRef
<T
> storage(llvm::size(range
));
1378 llvm::copy(range
, storage
.begin());
1380 // Assign this to the range slot and use the range as the value for the
1382 allocatedRangeMemory
.emplace_back(std::move(storage
));
1383 rangeMemory
[rangeIndex
] = allocatedRangeMemory
.back();
1385 memory
[memIndex
] = &rangeMemory
[rangeIndex
];
1388 // Dispatch based on the concrete range type.
1389 if constexpr (std::is_same_v
<T
, Type
>) {
1390 return assignRange(allocatedTypeRangeMemory
, typeRangeMemory
);
1391 } else if constexpr (std::is_same_v
<T
, Value
>) {
1392 return assignRange(allocatedValueRangeMemory
, valueRangeMemory
);
1394 llvm_unreachable("unhandled range type");
1398 /// The underlying bytecode buffer.
1399 const ByteCodeField
*curCodeIt
;
1401 /// The stack of bytecode positions at which to resume operation.
1402 SmallVector
<const ByteCodeField
*> resumeCodeIt
;
1404 /// The current execution memory.
1405 MutableArrayRef
<const void *> memory
;
1406 MutableArrayRef
<OwningOpRange
> opRangeMemory
;
1407 MutableArrayRef
<TypeRange
> typeRangeMemory
;
1408 std::vector
<llvm::OwningArrayRef
<Type
>> &allocatedTypeRangeMemory
;
1409 MutableArrayRef
<ValueRange
> valueRangeMemory
;
1410 std::vector
<llvm::OwningArrayRef
<Value
>> &allocatedValueRangeMemory
;
1412 /// The current loop indices.
1413 MutableArrayRef
<unsigned> loopIndex
;
1415 /// References to ByteCode data necessary for execution.
1416 ArrayRef
<const void *> uniquedMemory
;
1417 ArrayRef
<ByteCodeField
> code
;
1418 ArrayRef
<PatternBenefit
> currentPatternBenefits
;
1419 ArrayRef
<PDLByteCodePattern
> patterns
;
1420 ArrayRef
<PDLConstraintFunction
> constraintFunctions
;
1421 ArrayRef
<PDLRewriteFunction
> rewriteFunctions
;
1425 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter
&rewriter
) {
1426 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1427 ByteCodeField fun_idx
= read();
1428 SmallVector
<PDLValue
, 16> args
;
1429 readList
<PDLValue
>(args
);
1432 llvm::dbgs() << " * Arguments: ";
1433 llvm::interleaveComma(args
, llvm::dbgs());
1434 llvm::dbgs() << "\n";
1437 ByteCodeField isNegated
= read();
1439 llvm::dbgs() << " * isNegated: " << isNegated
<< "\n";
1440 llvm::interleaveComma(args
, llvm::dbgs());
1443 ByteCodeField numResults
= read();
1444 const PDLRewriteFunction
&constraintFn
= constraintFunctions
[fun_idx
];
1445 ByteCodeRewriteResultList
results(numResults
);
1446 LogicalResult rewriteResult
= constraintFn(rewriter
, results
, args
);
1447 [[maybe_unused
]] ArrayRef
<PDLValue
> constraintResults
= results
.getResults();
1449 if (succeeded(rewriteResult
)) {
1450 llvm::dbgs() << " * Constraint succeeded\n";
1451 llvm::dbgs() << " * Results: ";
1452 llvm::interleaveComma(constraintResults
, llvm::dbgs());
1453 llvm::dbgs() << "\n";
1455 llvm::dbgs() << " * Constraint failed\n";
1458 assert((failed(rewriteResult
) || constraintResults
.size() == numResults
) &&
1459 "native PDL rewrite function succeeded but returned "
1460 "unexpected number of results");
1461 processNativeFunResults(results
, numResults
, rewriteResult
);
1463 // Depending on the constraint jump to the proper destination.
1464 selectJump(isNegated
!= succeeded(rewriteResult
));
1467 LogicalResult
ByteCodeExecutor::executeApplyRewrite(PatternRewriter
&rewriter
) {
1468 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1469 const PDLRewriteFunction
&rewriteFn
= rewriteFunctions
[read()];
1470 SmallVector
<PDLValue
, 16> args
;
1471 readList
<PDLValue
>(args
);
1474 llvm::dbgs() << " * Arguments: ";
1475 llvm::interleaveComma(args
, llvm::dbgs());
1478 // Execute the rewrite function.
1479 ByteCodeField numResults
= read();
1480 ByteCodeRewriteResultList
results(numResults
);
1481 LogicalResult rewriteResult
= rewriteFn(rewriter
, results
, args
);
1483 assert(results
.getResults().size() == numResults
&&
1484 "native PDL rewrite function returned unexpected number of results");
1486 processNativeFunResults(results
, numResults
, rewriteResult
);
1488 if (failed(rewriteResult
)) {
1489 LLVM_DEBUG(llvm::dbgs() << " - Failed");
1495 void ByteCodeExecutor::processNativeFunResults(
1496 ByteCodeRewriteResultList
&results
, unsigned numResults
,
1497 LogicalResult
&rewriteResult
) {
1498 // Store the results in the bytecode memory or handle missing results on
1500 for (unsigned resultIdx
= 0; resultIdx
< numResults
; resultIdx
++) {
1501 PDLValue::Kind resultKind
= read
<PDLValue::Kind
>();
1503 // Skip the according number of values on the buffer on failure and exit
1504 // early as there are no results to process.
1505 if (failed(rewriteResult
)) {
1506 if (resultKind
== PDLValue::Kind::TypeRange
||
1507 resultKind
== PDLValue::Kind::ValueRange
) {
1514 PDLValue result
= results
.getResults()[resultIdx
];
1515 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result
<< "\n");
1516 assert(result
.getKind() == resultKind
&&
1517 "native PDL rewrite function returned an unexpected type of "
1519 // If the result is a range, we need to copy it over to the bytecodes
1521 if (std::optional
<TypeRange
> typeRange
= result
.dyn_cast
<TypeRange
>()) {
1522 unsigned rangeIndex
= read();
1523 typeRangeMemory
[rangeIndex
] = *typeRange
;
1524 memory
[read()] = &typeRangeMemory
[rangeIndex
];
1525 } else if (std::optional
<ValueRange
> valueRange
=
1526 result
.dyn_cast
<ValueRange
>()) {
1527 unsigned rangeIndex
= read();
1528 valueRangeMemory
[rangeIndex
] = *valueRange
;
1529 memory
[read()] = &valueRangeMemory
[rangeIndex
];
1531 memory
[read()] = result
.getAsOpaquePointer();
1535 // Copy over any underlying storage allocated for result ranges.
1536 for (auto &it
: results
.getAllocatedTypeRanges())
1537 allocatedTypeRangeMemory
.push_back(std::move(it
));
1538 for (auto &it
: results
.getAllocatedValueRanges())
1539 allocatedValueRangeMemory
.push_back(std::move(it
));
1542 void ByteCodeExecutor::executeAreEqual() {
1543 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1544 const void *lhs
= read
<const void *>();
1545 const void *rhs
= read
<const void *>();
1547 LLVM_DEBUG(llvm::dbgs() << " * " << lhs
<< " == " << rhs
<< "\n");
1548 selectJump(lhs
== rhs
);
1551 void ByteCodeExecutor::executeAreRangesEqual() {
1552 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1553 PDLValue::Kind valueKind
= read
<PDLValue::Kind
>();
1554 const void *lhs
= read
<const void *>();
1555 const void *rhs
= read
<const void *>();
1557 switch (valueKind
) {
1558 case PDLValue::Kind::TypeRange
: {
1559 const TypeRange
*lhsRange
= reinterpret_cast<const TypeRange
*>(lhs
);
1560 const TypeRange
*rhsRange
= reinterpret_cast<const TypeRange
*>(rhs
);
1561 LLVM_DEBUG(llvm::dbgs() << " * " << lhs
<< " == " << rhs
<< "\n\n");
1562 selectJump(*lhsRange
== *rhsRange
);
1565 case PDLValue::Kind::ValueRange
: {
1566 const auto *lhsRange
= reinterpret_cast<const ValueRange
*>(lhs
);
1567 const auto *rhsRange
= reinterpret_cast<const ValueRange
*>(rhs
);
1568 LLVM_DEBUG(llvm::dbgs() << " * " << lhs
<< " == " << rhs
<< "\n\n");
1569 selectJump(*lhsRange
== *rhsRange
);
1573 llvm_unreachable("unexpected `AreRangesEqual` value kind");
1577 void ByteCodeExecutor::executeBranch() {
1578 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1579 curCodeIt
= &code
[read
<ByteCodeAddr
>()];
1582 void ByteCodeExecutor::executeCheckOperandCount() {
1583 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1584 Operation
*op
= read
<Operation
*>();
1585 uint32_t expectedCount
= read
<uint32_t>();
1586 bool compareAtLeast
= read();
1588 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op
->getNumOperands() << "\n"
1589 << " * Expected: " << expectedCount
<< "\n"
1590 << " * Comparator: "
1591 << (compareAtLeast
? ">=" : "==") << "\n");
1593 selectJump(op
->getNumOperands() >= expectedCount
);
1595 selectJump(op
->getNumOperands() == expectedCount
);
1598 void ByteCodeExecutor::executeCheckOperationName() {
1599 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1600 Operation
*op
= read
<Operation
*>();
1601 OperationName expectedName
= read
<OperationName
>();
1603 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op
->getName() << "\"\n"
1604 << " * Expected: \"" << expectedName
<< "\"\n");
1605 selectJump(op
->getName() == expectedName
);
1608 void ByteCodeExecutor::executeCheckResultCount() {
1609 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1610 Operation
*op
= read
<Operation
*>();
1611 uint32_t expectedCount
= read
<uint32_t>();
1612 bool compareAtLeast
= read();
1614 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op
->getNumResults() << "\n"
1615 << " * Expected: " << expectedCount
<< "\n"
1616 << " * Comparator: "
1617 << (compareAtLeast
? ">=" : "==") << "\n");
1619 selectJump(op
->getNumResults() >= expectedCount
);
1621 selectJump(op
->getNumResults() == expectedCount
);
1624 void ByteCodeExecutor::executeCheckTypes() {
1625 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1626 TypeRange
*lhs
= read
<TypeRange
*>();
1627 Attribute rhs
= read
<Attribute
>();
1628 LLVM_DEBUG(llvm::dbgs() << " * " << lhs
<< " == " << rhs
<< "\n\n");
1630 selectJump(*lhs
== cast
<ArrayAttr
>(rhs
).getAsValueRange
<TypeAttr
>());
// Handles a `Continue` instruction: reads a ByteCodeField loop level from the
// byte-code stream and prints it under LLVM_DEBUG.
// NOTE(review): this listing looks truncated. Nothing after the debug print
// uses `level`, and no closing brace appears before the next definition, so
// statements (and the brace) seem to have been dropped. Restore them from the
// upstream source instead of trusting this body.
1633 void ByteCodeExecutor::executeContinue() {
1634 ByteCodeField level
= read();
1635 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1636 << " * Level: " << level
<< "\n");
1641 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1642 LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1643 unsigned memIndex
= read();
1644 unsigned rangeIndex
= read();
1645 ArrayAttr typesAttr
= cast
<ArrayAttr
>(read
<Attribute
>());
1647 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr
<< "\n\n");
1648 assignRangeToMemory(typesAttr
.getAsValueRange
<TypeAttr
>(), memIndex
,
// Handles `CreateOperation`. Reads, in order:
//   * the destination memory index;
//   * the OperationName, used with `mainRewriteLoc` to build an OperationState;
//   * the operand list;
//   * a count of (StringAttr name, Attribute) pairs;
//   * a result-type count.
// Pairs whose attribute is null are not added to the state. When the
// result-type count equals kInferTypesMarker, the types come from the op's
// InferTypeOpInterface. Otherwise each result entry is tagged with a
// PDLValue::Kind; judging by the two reads, it is either a single Type or a
// TypeRange* whose types are appended. The op is built with
// `rewriter.create(state)` and stored in memory[memIndex].
// NOTE(review): this listing has dropped lines. Missing are the statement run
// when inferReturnTypes fails, the `else` branches separating the inferred
// vs. explicit result types and the Type vs. TypeRange entries, several
// closing braces, and the LLVM_DEBUG wrapper around the trailing dbgs()
// calls. Restore from the upstream source before relying on the control flow
// here.
1652 void ByteCodeExecutor::executeCreateOperation(PatternRewriter
&rewriter
,
1653 Location mainRewriteLoc
) {
1654 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1656 unsigned memIndex
= read();
1657 OperationState
state(mainRewriteLoc
, read
<OperationName
>());
1658 readList(state
.operands
);
1659 for (unsigned i
= 0, e
= read(); i
!= e
; ++i
) {
1660 StringAttr name
= read
<StringAttr
>();
1661 if (Attribute attr
= read
<Attribute
>())
1662 state
.addAttribute(name
, attr
);
1665 // Read in the result types. If the "size" is the sentinel value, this
1666 // indicates that the result types should be inferred.
1667 unsigned numResults
= read();
1668 if (numResults
== kInferTypesMarker
) {
1669 InferTypeOpInterface::Concept
*inferInterface
=
1670 state
.name
.getInterface
<InferTypeOpInterface
>();
1671 assert(inferInterface
&&
1672 "expected operation to provide InferTypeOpInterface");
1674 // TODO: Handle failure.
1675 if (failed(inferInterface
->inferReturnTypes(
1676 state
.getContext(), state
.location
, state
.operands
,
1677 state
.attributes
.getDictionary(state
.getContext()),
1678 state
.getRawProperties(), state
.regions
, state
.types
)))
1681 // Otherwise, this is a fixed number of results.
1682 for (unsigned i
= 0; i
!= numResults
; ++i
) {
1683 if (read
<PDLValue::Kind
>() == PDLValue::Kind::Type
) {
1684 state
.types
.push_back(read
<Type
>());
1686 TypeRange
*resultTypes
= read
<TypeRange
*>();
1687 state
.types
.append(resultTypes
->begin(), resultTypes
->end());
1692 Operation
*resultOp
= rewriter
.create(state
);
1693 memory
[memIndex
] = resultOp
;
1696 llvm::dbgs() << " * Attributes: "
1697 << state
.attributes
.getDictionary(state
.getContext())
1698 << "\n * Operands: ";
1699 llvm::interleaveComma(state
.operands
, llvm::dbgs());
1700 llvm::dbgs() << "\n * Result Types: ";
1701 llvm::interleaveComma(state
.types
, llvm::dbgs());
1702 llvm::dbgs() << "\n * Result: " << *resultOp
<< "\n";
// Handles CreateDynamic{Type,Value}Range. `T` is the element type of the range
// (the dispatcher instantiates it with Type and Value). `type` is used only
// for debug output. Reads a memory index and a range index, then stores
// `values` via assignRangeToMemory(values, memIndex, rangeIndex).
// NOTE(review): in this listing nothing populates `values` between its
// declaration and its use, and the LLVM_DEBUG wrapper around the dbgs() calls
// and the closing brace are missing. Lines appear to have been dropped;
// restore from the upstream source.
1706 template <typename T
>
1707 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type
) {
1708 LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type
<< "Range:\n");
1709 unsigned memIndex
= read();
1710 unsigned rangeIndex
= read();
1711 SmallVector
<T
> values
;
1715 llvm::dbgs() << "\n * " << type
<< "s: ";
1716 llvm::interleaveComma(values
, llvm::dbgs());
1717 llvm::dbgs() << "\n";
1720 assignRangeToMemory(values
, memIndex
, rangeIndex
);
1723 void ByteCodeExecutor::executeEraseOp(PatternRewriter
&rewriter
) {
1724 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1725 Operation
*op
= read
<Operation
*>();
1727 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op
<< "\n");
1728 rewriter
.eraseOp(op
);
// Handles Extract{Op,Type,Value}. `T` is the element type, `Range` is the
// range type read from the stream, and `kind` is used only in debug output.
// Reads a Range*, a uint32_t index and a memory index. It then stores element
// `index` of the range, or a default-constructed T when `index` is out of
// bounds, via storeToMemory.
// NOTE(review): the unconditional `memory[memIndex] = nullptr;` followed by a
// dereference of `range` suggests a dropped guard (e.g. a null check on
// `range` with an early return). The closing brace is also missing. Confirm
// against the upstream source.
1731 template <typename T
, typename Range
, PDLValue::Kind kind
>
1732 void ByteCodeExecutor::executeExtract() {
1733 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind
<< ":\n");
1734 Range
*range
= read
<Range
*>();
1735 unsigned index
= read
<uint32_t>();
1736 unsigned memIndex
= read();
1739 memory
[memIndex
] = nullptr;
1743 T result
= index
< range
->size() ? (*range
)[index
] : T();
1744 LLVM_DEBUG(llvm::dbgs() << " * " << kind
<< "s(" << range
->size() << ")\n"
1745 << " * Index: " << index
<< "\n"
1746 << " * Result: " << result
<< "\n");
1747 storeToMemory(memIndex
, result
);
1750 void ByteCodeExecutor::executeFinalize() {
1751 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
// Handles `ForEach`; only the Operation value kind is handled here. It
// captures `prevCodeIt` from getPrevCodeIt(), then reads the range index, the
// loop-variable memory slot, the value kind and a loopIndex slot. While the
// loop index is below opRangeMemory[rangeIndex].size(), the current element is
// stored in memory[memIndex], `prevCodeIt` is pushed with pushCodeIt, and the
// successor address is read and discarded so execution continues into the
// loop body. Once the range is exhausted it calls selectJump(size_t(0)).
// NOTE(review): the `break`/`return` statements, the `default:` label and
// several closing braces are missing from this listing. As shown, the
// exhausted-range path would fall into llvm_unreachable. Restore from the
// upstream source.
1754 void ByteCodeExecutor::executeForEach() {
1755 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1756 const ByteCodeField
*prevCodeIt
= getPrevCodeIt();
1757 unsigned rangeIndex
= read();
1758 unsigned memIndex
= read();
1759 const void *value
= nullptr;
1761 switch (read
<PDLValue::Kind
>()) {
1762 case PDLValue::Kind::Operation
: {
1763 unsigned &index
= loopIndex
[read()];
1764 ArrayRef
<Operation
*> array
= opRangeMemory
[rangeIndex
];
1765 assert(index
<= array
.size() && "iterated past the end");
1766 if (index
< array
.size()) {
1767 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array
[index
] << "\n");
1768 value
= array
[index
];
1772 LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1774 selectJump(size_t(0));
1778 llvm_unreachable("unexpected `ForEach` value kind");
1781 // Store the iterate value and the stack address.
1782 memory
[memIndex
] = value
;
1783 pushCodeIt(prevCodeIt
);
1785 // Skip over the successor (we will enter the body of the loop).
1786 read
<ByteCodeAddr
>();
1789 void ByteCodeExecutor::executeGetAttribute() {
1790 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1791 unsigned memIndex
= read();
1792 Operation
*op
= read
<Operation
*>();
1793 StringAttr attrName
= read
<StringAttr
>();
1794 Attribute attr
= op
->getAttr(attrName
);
1796 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op
<< "\n"
1797 << " * Attribute: " << attrName
<< "\n"
1798 << " * Result: " << attr
<< "\n");
1799 memory
[memIndex
] = attr
.getAsOpaquePointer();
1802 void ByteCodeExecutor::executeGetAttributeType() {
1803 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1804 unsigned memIndex
= read();
1805 Attribute attr
= read
<Attribute
>();
1807 if (auto typedAttr
= dyn_cast
<TypedAttr
>(attr
))
1808 type
= typedAttr
.getType();
1810 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr
<< "\n"
1811 << " * Result: " << type
<< "\n");
1812 memory
[memIndex
] = type
.getAsOpaquePointer();
// Handles `GetDefiningOp`: reads a memory index and a PDLValue::Kind.
//   * Value kind: uses the defining op of the Value that is read.
//   * Otherwise: reads a ValueRange* and, if it is non-null and non-empty,
//     uses the defining op of its first element.
// The resulting Operation* (nullptr if none was assigned) is stored in
// memory[memIndex].
// NOTE(review): the `else` between the Value and ValueRange paths, any guard
// around `value.getDefiningOp()`, and several closing braces are missing from
// this listing. Lines appear to have been dropped.
1815 void ByteCodeExecutor::executeGetDefiningOp() {
1816 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1817 unsigned memIndex
= read();
1818 Operation
*op
= nullptr;
1819 if (read
<PDLValue::Kind
>() == PDLValue::Kind::Value
) {
1820 Value value
= read
<Value
>();
1822 op
= value
.getDefiningOp();
1823 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value
<< "\n");
1825 ValueRange
*values
= read
<ValueRange
*>();
1826 if (values
&& !values
->empty()) {
1827 op
= values
->front().getDefiningOp();
1829 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values
<< "\n");
1832 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op
<< "\n");
1833 memory
[memIndex
] = op
;
1836 void ByteCodeExecutor::executeGetOperand(unsigned index
) {
1837 Operation
*op
= read
<Operation
*>();
1838 unsigned memIndex
= read();
1840 index
< op
->getNumOperands() ? op
->getOperand(index
) : Value();
1842 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op
<< "\n"
1843 << " * Index: " << index
<< "\n"
1844 << " * Result: " << operand
<< "\n");
1845 memory
[memIndex
] = operand
.getAsOpaquePointer();
// Shared implementation of GetOperands/GetResults. `values` is the full
// operand or result range of `op`. It is narrowed to group `index` as follows:
//   * index == UINT32_MAX: keep the whole range;
//   * op has the AttrSizedSegmentsT trait: slice using the DenseI32ArrayAttr
//     named `attrSizedSegments`. The slice starts at the sum of the preceding
//     segment sizes and has length segments[index];
//   * otherwise, if values.size() >= index: drop the first `index` values and
//     treat the rest as a trailing variadic group.
// If `rangeIndex` is not the ByteCodeField max sentinel, the range is stored
// in valueRangeMemory[rangeIndex] and a pointer to that slot is returned.
// Otherwise, the opaque pointer of the single value is returned, or nullptr
// when the range does not hold exactly one value.
// NOTE(review): the return-type line, the statements executed when the
// segment attribute is missing or too short, the body of the final "bail
// out" path, and several closing braces are missing from this listing.
// Restore from the upstream source.
1848 /// This function is the internal implementation of `GetResults` and
1849 /// `GetOperands` that provides support for extracting a value range from the
1850 /// given operation.
1851 template <template <typename
> class AttrSizedSegmentsT
, typename RangeT
>
1853 executeGetOperandsResults(RangeT values
, Operation
*op
, unsigned index
,
1854 ByteCodeField rangeIndex
, StringRef attrSizedSegments
,
1855 MutableArrayRef
<ValueRange
> valueRangeMemory
) {
1856 // Check for the sentinel index that signals that all values should be returned.
1858 if (index
== std::numeric_limits
<uint32_t>::max()) {
1859 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1860 // `values` is already the full value range.
1862 // Otherwise, check to see if this operation uses AttrSizedSegments.
1863 } else if (op
->hasTrait
<AttrSizedSegmentsT
>()) {
1864 LLVM_DEBUG(llvm::dbgs()
1865 << " * Extracting values from `" << attrSizedSegments
<< "`\n");
1867 auto segmentAttr
= op
->getAttrOfType
<DenseI32ArrayAttr
>(attrSizedSegments
);
1868 if (!segmentAttr
|| segmentAttr
.asArrayRef().size() <= index
)
1871 ArrayRef
<int32_t> segments
= segmentAttr
;
1872 unsigned startIndex
=
1873 std::accumulate(segments
.begin(), segments
.begin() + index
, 0);
1874 values
= values
.slice(startIndex
, *std::next(segments
.begin(), index
));
1876 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex
<< ", "
1877 << *std::next(segments
.begin(), index
) << "]\n");
1879 // Otherwise, assume this is the last operand group of the operation.
1880 // FIXME: We currently don't support operations with
1881 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1882 // have a way to detect its presence.
1883 } else if (values
.size() >= index
) {
1884 LLVM_DEBUG(llvm::dbgs()
1885 << " * Treating values as trailing variadic range\n");
1886 values
= values
.drop_front(index
);
1888 // If we couldn't detect a way to compute the values, bail out.
1893 // If the range index is valid, we are returning a range.
1894 if (rangeIndex
!= std::numeric_limits
<ByteCodeField
>::max()) {
1895 valueRangeMemory
[rangeIndex
] = values
;
1896 return &valueRangeMemory
[rangeIndex
];
1899 // If a range index wasn't provided, the range is required to be non-variadic.
1900 return values
.size() != 1 ? nullptr : values
.front().getAsOpaquePointer();
// Handles `GetOperands`: reads a uint32_t group index, an Operation* and a
// range index. It resolves the operand group with executeGetOperandsResults,
// using the "operandSegmentSizes" attribute for AttrSizedOperandSegments ops,
// and stores the result in the memory slot read last.
// NOTE(review): the final argument of the executeGetOperandsResults call, the
// condition guarding the "Invalid operand range" debug print, and the closing
// brace are missing from this listing.
1903 void ByteCodeExecutor::executeGetOperands() {
1904 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1905 unsigned index
= read
<uint32_t>();
1906 Operation
*op
= read
<Operation
*>();
1907 ByteCodeField rangeIndex
= read();
1909 void *result
= executeGetOperandsResults
<OpTrait::AttrSizedOperandSegments
>(
1910 op
->getOperands(), op
, index
, rangeIndex
, "operandSegmentSizes",
1913 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1914 memory
[read()] = result
;
1917 void ByteCodeExecutor::executeGetResult(unsigned index
) {
1918 Operation
*op
= read
<Operation
*>();
1919 unsigned memIndex
= read();
1921 index
< op
->getNumResults() ? op
->getResult(index
) : OpResult();
1923 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op
<< "\n"
1924 << " * Index: " << index
<< "\n"
1925 << " * Result: " << result
<< "\n");
1926 memory
[memIndex
] = result
.getAsOpaquePointer();
// Handles `GetResults`: reads a uint32_t group index, an Operation* and a
// range index. It resolves the result group with executeGetOperandsResults,
// using the "resultSegmentSizes" attribute for AttrSizedResultSegments ops,
// and stores the result in the memory slot read last.
// NOTE(review): the final argument of the executeGetOperandsResults call, the
// condition guarding the "Invalid result range" debug print, and the closing
// brace are missing from this listing.
1929 void ByteCodeExecutor::executeGetResults() {
1930 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1931 unsigned index
= read
<uint32_t>();
1932 Operation
*op
= read
<Operation
*>();
1933 ByteCodeField rangeIndex
= read();
1935 void *result
= executeGetOperandsResults
<OpTrait::AttrSizedResultSegments
>(
1936 op
->getResults(), op
, index
, rangeIndex
, "resultSegmentSizes",
1939 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1940 memory
[read()] = result
;
// Handles `GetUsers`: reads a memory index and an op-range index. It resets
// opRangeMemory[rangeIndex] and points memory[memIndex] at it, then fills it
// with the users of either a single Value or every value of a ValueRange.
// For the range case, users are concatenated in value order; no
// de-duplication is performed here.
// NOTE(review): lines appear to be missing from this listing. Examples are any
// handling between reading `value` and logging it, the `else` separating the
// Value and ValueRange paths, the LLVM_DEBUG wrapper around the
// " * Values (" output, and closing braces.
1943 void ByteCodeExecutor::executeGetUsers() {
1944 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1945 unsigned memIndex
= read();
1946 unsigned rangeIndex
= read();
1947 OwningOpRange
&range
= opRangeMemory
[rangeIndex
];
1948 memory
[memIndex
] = &range
;
1950 range
= OwningOpRange();
1951 if (read
<PDLValue::Kind
>() == PDLValue::Kind::Value
) {
1953 Value value
= read
<Value
>();
1956 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value
<< "\n");
1958 // Extract the users of a single value.
1959 range
= OwningOpRange(std::distance(value
.user_begin(), value
.user_end()));
1960 llvm::copy(value
.getUsers(), range
.begin());
1962 // Read a range of values.
1963 ValueRange
*values
= read
<ValueRange
*>();
1967 llvm::dbgs() << " * Values (" << values
->size() << "): ";
1968 llvm::interleaveComma(*values
, llvm::dbgs());
1969 llvm::dbgs() << "\n";
1972 // Extract all the users of a range of values.
1973 SmallVector
<Operation
*> users
;
1974 for (Value value
: *values
)
1975 users
.append(value
.user_begin(), value
.user_end());
1976 range
= OwningOpRange(users
.size());
1977 llvm::copy(users
, range
.begin());
1980 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range
.size() << " operations\n");
1983 void ByteCodeExecutor::executeGetValueType() {
1984 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1985 unsigned memIndex
= read();
1986 Value value
= read
<Value
>();
1987 Type type
= value
? value
.getType() : Type();
1989 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value
<< "\n"
1990 << " * Result: " << type
<< "\n");
1991 memory
[memIndex
] = type
.getAsOpaquePointer();
// Handles `GetValueRangeTypes`: reads a memory index, a range index and a
// ValueRange*. On the "<NULL>" path, memory[memIndex] is set to nullptr;
// presumably this is taken when `values` is null. Otherwise
// typeRangeMemory[rangeIndex] receives values->getType() and memory[memIndex]
// points at that slot.
// NOTE(review): the null check that selects the "<NULL>" path, its early
// return, the LLVM_DEBUG wrapper around the range printout, and the closing
// brace are missing from this listing.
1994 void ByteCodeExecutor::executeGetValueRangeTypes() {
1995 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1996 unsigned memIndex
= read();
1997 unsigned rangeIndex
= read();
1998 ValueRange
*values
= read
<ValueRange
*>();
2000 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
2001 memory
[memIndex
] = nullptr;
2006 llvm::dbgs() << " * Values (" << values
->size() << "): ";
2007 llvm::interleaveComma(*values
, llvm::dbgs());
2008 llvm::dbgs() << "\n * Result: ";
2009 llvm::interleaveComma(values
->getType(), llvm::dbgs());
2010 llvm::dbgs() << "\n";
2012 typeRangeMemory
[rangeIndex
] = values
->getType();
2013 memory
[memIndex
] = &typeRangeMemory
[rangeIndex
];
2016 void ByteCodeExecutor::executeIsNotNull() {
2017 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
2018 const void *value
= read
<const void *>();
2020 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value
<< "\n");
2021 selectJump(value
!= nullptr);
// Handles `RecordMatch`: reads a pattern index and a destination address.
// If the pattern's current benefit is impossible-to-match, the rest of the
// record is skipped. Otherwise it does the following:
//   * builds a fused location from the Operation*s listed in the stream;
//   * appends a MatchResult for patterns[patternIndex] to `matches`;
//   * copies the match inputs into that result.
// TypeRange/ValueRange inputs are copied into the match's own
// typeRangeValues/valueRangeValues vectors, and match.values points at the
// copies; other inputs are stored as opaque pointers. Both vectors are
// reserved to numInputs before any push_back, which appears to be what keeps
// the &...back() pointers stored in match.values valid.
// NOTE(review): the statements after the impossible-benefit debug print, the
// `break`s and `default:` label in the input switch, and the code following
// the loop are missing from this listing. As shown, `dest` is never used.
2024 void ByteCodeExecutor::executeRecordMatch(
2025 PatternRewriter
&rewriter
,
2026 SmallVectorImpl
<PDLByteCode::MatchResult
> &matches
) {
2027 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
2028 unsigned patternIndex
= read();
2029 PatternBenefit benefit
= currentPatternBenefits
[patternIndex
];
2030 const ByteCodeField
*dest
= &code
[read
<ByteCodeAddr
>()];
2032 // If the benefit of the pattern is impossible, skip the processing of the
2033 // rest of the pattern.
2034 if (benefit
.isImpossibleToMatch()) {
2035 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
2040 // Create a fused location containing the locations of each of the
2041 // operations used in the match. This will be used as the location for
2042 // created operations during the rewrite that don't already have an
2043 // explicit location set.
2044 unsigned numMatchLocs
= read();
2045 SmallVector
<Location
, 4> matchLocs
;
2046 matchLocs
.reserve(numMatchLocs
);
2047 for (unsigned i
= 0; i
!= numMatchLocs
; ++i
)
2048 matchLocs
.push_back(read
<Operation
*>()->getLoc());
2049 Location matchLoc
= rewriter
.getFusedLoc(matchLocs
);
2051 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit
.getBenefit() << "\n"
2052 << " * Location: " << matchLoc
<< "\n");
2053 matches
.emplace_back(matchLoc
, patterns
[patternIndex
], benefit
);
2054 PDLByteCode::MatchResult
&match
= matches
.back();
2056 // Record all of the inputs to the match. If any of the inputs are ranges, we
2057 // will also need to remap the range pointer to memory stored in the match.
2059 unsigned numInputs
= read();
2060 match
.values
.reserve(numInputs
);
2061 match
.typeRangeValues
.reserve(numInputs
);
2062 match
.valueRangeValues
.reserve(numInputs
);
2063 for (unsigned i
= 0; i
< numInputs
; ++i
) {
2064 switch (read
<PDLValue::Kind
>()) {
2065 case PDLValue::Kind::TypeRange
:
2066 match
.typeRangeValues
.push_back(*read
<TypeRange
*>());
2067 match
.values
.push_back(&match
.typeRangeValues
.back());
2069 case PDLValue::Kind::ValueRange
:
2070 match
.valueRangeValues
.push_back(*read
<ValueRange
*>());
2071 match
.values
.push_back(&match
.valueRangeValues
.back());
2074 match
.values
.push_back(read
<const void *>());
// Handles `ReplaceOp`: reads the Operation* to replace and calls
// rewriter.replaceOp(op, args).
// NOTE(review): in this listing nothing fills `args` before the replacement,
// and the LLVM_DEBUG wrapper and part of the debug output are missing. Lines
// appear to have been dropped; restore from the upstream source.
2081 void ByteCodeExecutor::executeReplaceOp(PatternRewriter
&rewriter
) {
2082 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2083 Operation
*op
= read
<Operation
*>();
2084 SmallVector
<Value
, 16> args
;
2088 llvm::dbgs() << " * Operation: " << *op
<< "\n"
2090 llvm::interleaveComma(args
, llvm::dbgs());
2091 llvm::dbgs() << "\n";
2093 rewriter
.replaceOp(op
, args
);
2096 void ByteCodeExecutor::executeSwitchAttribute() {
2097 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2098 Attribute value
= read
<Attribute
>();
2099 ArrayAttr cases
= read
<ArrayAttr
>();
2100 handleSwitch(value
, cases
);
2103 void ByteCodeExecutor::executeSwitchOperandCount() {
2104 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2105 Operation
*op
= read
<Operation
*>();
2106 auto cases
= read
<DenseIntOrFPElementsAttr
>().getValues
<uint32_t>();
2108 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op
<< "\n");
2109 handleSwitch(op
->getNumOperands(), cases
);
// Handles `SwitchOperationName`: reads an Operation* and a case count,
// followed by `caseCount` OperationNames stored inline in the stream.
//   * A case equals the operation's name: curCodeIt is advanced past the
//     remaining inline names and successor i + 1 is selected.
//   * No case matches: successor 0 is selected.
// For debug printing, the inline case list is read once and curCodeIt is then
// rewound to `prevCodeIt`, so the matching loop re-reads it from the start.
// NOTE(review): the LLVM_DEBUG wrapper, part of the " * Value" / cases
// printout, and several closing braces are missing from this listing.
2112 void ByteCodeExecutor::executeSwitchOperationName() {
2113 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2114 OperationName value
= read
<Operation
*>()->getName();
2115 size_t caseCount
= read();
2117 // The operation names are stored in-line, so to print them out for
2118 // debugging purposes we need to read the array before executing the
2119 // switch so that we can display all of the possible values.
2121 const ByteCodeField
*prevCodeIt
= curCodeIt
;
2122 llvm::dbgs() << " * Value: " << value
<< "\n"
2124 llvm::interleaveComma(
2125 llvm::map_range(llvm::seq
<size_t>(0, caseCount
),
2126 [&](size_t) { return read
<OperationName
>(); }),
2128 llvm::dbgs() << "\n";
2129 curCodeIt
= prevCodeIt
;
2132 // Try to find the switch value within any of the cases.
2133 for (size_t i
= 0; i
!= caseCount
; ++i
) {
2134 if (read
<OperationName
>() == value
) {
2135 curCodeIt
+= (caseCount
- i
- 1);
2136 return selectJump(i
+ 1);
2139 selectJump(size_t(0));
2142 void ByteCodeExecutor::executeSwitchResultCount() {
2143 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2144 Operation
*op
= read
<Operation
*>();
2145 auto cases
= read
<DenseIntOrFPElementsAttr
>().getValues
<uint32_t>();
2147 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op
<< "\n");
2148 handleSwitch(op
->getNumResults(), cases
);
2151 void ByteCodeExecutor::executeSwitchType() {
2152 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2153 Type value
= read
<Type
>();
2154 auto cases
= read
<ArrayAttr
>().getAsValueRange
<TypeAttr
>();
2155 handleSwitch(value
, cases
);
// Handles `SwitchTypes`: reads a TypeRange* and an ArrayAttr whose elements
// are ArrayAttrs of TypeAttrs. A case matches when the range equals that
// case's types. The "<NULL>" path selects successor 0; presumably it is taken
// when `value` is null.
// NOTE(review): the null check guarding the "<NULL>" path, the closing tokens
// of the handleSwitch call and its lambda, and the function's closing brace
// are missing from this listing.
2158 void ByteCodeExecutor::executeSwitchTypes() {
2159 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2160 TypeRange
*value
= read
<TypeRange
*>();
2161 auto cases
= read
<ArrayAttr
>().getAsRange
<ArrayAttr
>();
2163 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2164 return selectJump(size_t(0));
2166 handleSwitch(*value
, cases
, [](ArrayAttr caseValue
, const TypeRange
&value
) {
2167 return value
== caseValue
.getAsValueRange
<TypeAttr
>();
// Main interpreter dispatch. It prints the inline Location under LLVM_DEBUG,
// reads an OpCode from the stream, and calls the matching execute* handler.
// GetOperand0../GetResult0.. derive the index from the opcode offset, while
// GetOperandN/GetResultN read it as a uint32_t. `matches` must be non-null for
// RecordMatch (see the assertion message), and `mainRewriteLoc` is
// dereferenced for CreateOperation.
// NOTE(review): this listing has lost most of the control structure: the
// return type (match() stores the result in a LogicalResult), the loop and
// `switch (opCode)` headers, many `case` labels, the handling after a failed
// executeApplyRewrite, and the `break`/`return` statements and closing
// braces. Do not edit this body without restoring it from the upstream
// source.
2172 ByteCodeExecutor::execute(PatternRewriter
&rewriter
,
2173 SmallVectorImpl
<PDLByteCode::MatchResult
> *matches
,
2174 std::optional
<Location
> mainRewriteLoc
) {
2176 // Print the location of the operation being executed.
2177 LLVM_DEBUG(llvm::dbgs() << readInline
<Location
>() << "\n");
2179 OpCode opCode
= static_cast<OpCode
>(read());
2181 case ApplyConstraint
:
2182 executeApplyConstraint(rewriter
);
2185 if (failed(executeApplyRewrite(rewriter
)))
2191 case AreRangesEqual
:
2192 executeAreRangesEqual();
2197 case CheckOperandCount
:
2198 executeCheckOperandCount();
2200 case CheckOperationName
:
2201 executeCheckOperationName();
2203 case CheckResultCount
:
2204 executeCheckResultCount();
2207 executeCheckTypes();
2212 case CreateConstantTypeRange
:
2213 executeCreateConstantTypeRange();
2215 case CreateOperation
:
2216 executeCreateOperation(rewriter
, *mainRewriteLoc
);
2218 case CreateDynamicTypeRange
:
2219 executeDynamicCreateRange
<Type
>("Type");
2221 case CreateDynamicValueRange
:
2222 executeDynamicCreateRange
<Value
>("Value");
2225 executeEraseOp(rewriter
);
2228 executeExtract
<Operation
*, OwningOpRange
, PDLValue::Kind::Operation
>();
2231 executeExtract
<Type
, TypeRange
, PDLValue::Kind::Type
>();
2234 executeExtract
<Value
, ValueRange
, PDLValue::Kind::Value
>();
2238 LLVM_DEBUG(llvm::dbgs() << "\n");
2244 executeGetAttribute();
2246 case GetAttributeType
:
2247 executeGetAttributeType();
2250 executeGetDefiningOp();
2256 unsigned index
= opCode
- GetOperand0
;
2257 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index
<< ":\n");
2258 executeGetOperand(index
);
2262 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2263 executeGetOperand(read
<uint32_t>());
2266 executeGetOperands();
2272 unsigned index
= opCode
- GetResult0
;
2273 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index
<< ":\n");
2274 executeGetResult(index
);
2278 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2279 executeGetResult(read
<uint32_t>());
2282 executeGetResults();
2288 executeGetValueType();
2290 case GetValueRangeTypes
:
2291 executeGetValueRangeTypes();
2298 "expected matches to be provided when executing the matcher");
2299 executeRecordMatch(rewriter
, *matches
);
2302 executeReplaceOp(rewriter
);
2304 case SwitchAttribute
:
2305 executeSwitchAttribute();
2307 case SwitchOperandCount
:
2308 executeSwitchOperandCount();
2310 case SwitchOperationName
:
2311 executeSwitchOperationName();
2313 case SwitchResultCount
:
2314 executeSwitchResultCount();
2317 executeSwitchType();
2320 executeSwitchTypes();
2323 LLVM_DEBUG(llvm::dbgs() << "\n");
2327 void PDLByteCode::match(Operation
*op
, PatternRewriter
&rewriter
,
2328 SmallVectorImpl
<MatchResult
> &matches
,
2329 PDLByteCodeMutableState
&state
) const {
2330 // The first memory slot is always the root operation.
2331 state
.memory
[0] = op
;
2333 // The matcher function always starts at code address 0.
2334 ByteCodeExecutor
executor(
2335 matcherByteCode
.data(), state
.memory
, state
.opRangeMemory
,
2336 state
.typeRangeMemory
, state
.allocatedTypeRangeMemory
,
2337 state
.valueRangeMemory
, state
.allocatedValueRangeMemory
, state
.loopIndex
,
2338 uniquedData
, matcherByteCode
, state
.currentPatternBenefits
, patterns
,
2339 constraintFunctions
, rewriteFunctions
);
2340 LogicalResult executeResult
= executor
.execute(rewriter
, &matches
);
2341 (void)executeResult
;
2342 assert(succeeded(executeResult
) && "unexpected matcher execution failure");
2344 // Order the found matches by benefit.
2345 std::stable_sort(matches
.begin(), matches
.end(),
2346 [](const MatchResult
&lhs
, const MatchResult
&rhs
) {
2347 return lhs
.benefit
> rhs
.benefit
;
2351 LogicalResult
PDLByteCode::rewrite(PatternRewriter
&rewriter
,
2352 const MatchResult
&match
,
2353 PDLByteCodeMutableState
&state
) const {
2354 auto *configSet
= match
.pattern
->getConfigSet();
2356 configSet
->notifyRewriteBegin(rewriter
);
2358 // The arguments of the rewrite function are stored at the start of the
2360 llvm::copy(match
.values
, state
.memory
.begin());
2362 ByteCodeExecutor
executor(
2363 &rewriterByteCode
[match
.pattern
->getRewriterAddr()], state
.memory
,
2364 state
.opRangeMemory
, state
.typeRangeMemory
,
2365 state
.allocatedTypeRangeMemory
, state
.valueRangeMemory
,
2366 state
.allocatedValueRangeMemory
, state
.loopIndex
, uniquedData
,
2367 rewriterByteCode
, state
.currentPatternBenefits
, patterns
,
2368 constraintFunctions
, rewriteFunctions
);
2369 LogicalResult result
=
2370 executor
.execute(rewriter
, /*matches=*/nullptr, match
.location
);
2373 configSet
->notifyRewriteEnd(rewriter
);
2375 // If the rewrite failed, check if the pattern rewriter can recover. If it
2376 // can, we can signal to the pattern applicator to keep trying patterns. If it
2377 // doesn't, we need to bail. Bailing here should be fine, given that we have
2378 // no means to propagate such a failure to the user, and it also indicates a
2379 // bug in the user code (i.e. failable rewrites should not be used with
2380 // pattern rewriters that don't support it).
2381 if (failed(result
) && !rewriter
.canRecoverFromRewriteFailure()) {
2382 LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2383 llvm::report_fatal_error(
2384 "Native PDL Rewrite failed, but the pattern "
2385 "rewriter doesn't support recovery. Failable pattern rewrites should "
2386 "not be used with pattern rewriters that do not support them.");