[SimplifyCFG] FoldTwoEntryPHINode(): consider *total* speculation cost, not per-BB...
[llvm-complete.git] / include / llvm / Transforms / Scalar / GVNExpression.h
blob3dc4515f85a15cd5c21e54a01090856a39d01a19
1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 ///
11 /// The header file for the GVN pass that contains expression handling
12 /// classes
14 //===----------------------------------------------------------------------===//
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <cassert>
32 #include <iterator>
33 #include <utility>
35 namespace llvm {
37 class BasicBlock;
38 class Type;
40 namespace GVNExpression {
/// Discriminator for the Expression hierarchy, consumed by the classof()
/// implementations that drive isa<>/cast<>.
///
/// The *Start / *End enumerators are range sentinels, not real kinds:
/// BasicExpression::classof accepts every type strictly between ET_BasicStart
/// and ET_BasicEnd, and MemoryExpression::classof every type strictly between
/// ET_MemoryStart and ET_MemoryEnd. New kinds must be placed inside the range
/// of the class they derive from.
enum ExpressionType {
  // Kinds that are not BasicExpressions.
  ET_Base,
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,

  // BasicExpression kinds.
  ET_BasicStart,
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,

  // MemoryExpression kinds (a subset of the BasicExpression range).
  ET_MemoryStart,
  ET_Call,
  ET_Load,
  ET_Store,
  ET_MemoryEnd,

  ET_BasicEnd
};
60 class Expression {
61 private:
62 ExpressionType EType;
63 unsigned Opcode;
64 mutable hash_code HashVal = 0;
66 public:
67 Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
68 : EType(ET), Opcode(O) {}
69 Expression(const Expression &) = delete;
70 Expression &operator=(const Expression &) = delete;
71 virtual ~Expression();
73 static unsigned getEmptyKey() { return ~0U; }
74 static unsigned getTombstoneKey() { return ~1U; }
76 bool operator!=(const Expression &Other) const { return !(*this == Other); }
77 bool operator==(const Expression &Other) const {
78 if (getOpcode() != Other.getOpcode())
79 return false;
80 if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
81 return true;
82 // Compare the expression type for anything but load and store.
83 // For load and store we set the opcode to zero to make them equal.
84 if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
85 getExpressionType() != Other.getExpressionType())
86 return false;
88 return equals(Other);
91 hash_code getComputedHash() const {
92 // It's theoretically possible for a thing to hash to zero. In that case,
93 // we will just compute the hash a few extra times, which is no worse that
94 // we did before, which was to compute it always.
95 if (static_cast<unsigned>(HashVal) == 0)
96 HashVal = getHashValue();
97 return HashVal;
100 virtual bool equals(const Expression &Other) const { return true; }
102 // Return true if the two expressions are exactly the same, including the
103 // normally ignored fields.
104 virtual bool exactlyEquals(const Expression &Other) const {
105 return getExpressionType() == Other.getExpressionType() && equals(Other);
108 unsigned getOpcode() const { return Opcode; }
109 void setOpcode(unsigned opcode) { Opcode = opcode; }
110 ExpressionType getExpressionType() const { return EType; }
112 // We deliberately leave the expression type out of the hash value.
113 virtual hash_code getHashValue() const { return getOpcode(); }
115 // Debugging support
116 virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
117 if (PrintEType)
118 OS << "etype = " << getExpressionType() << ",";
119 OS << "opcode = " << getOpcode() << ", ";
122 void print(raw_ostream &OS) const {
123 OS << "{ ";
124 printInternal(OS, true);
125 OS << "}";
128 LLVM_DUMP_METHOD void dump() const;
131 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
132 E.print(OS);
133 return OS;
136 class BasicExpression : public Expression {
137 private:
138 using RecyclerType = ArrayRecycler<Value *>;
139 using RecyclerCapacity = RecyclerType::Capacity;
141 Value **Operands = nullptr;
142 unsigned MaxOperands;
143 unsigned NumOperands = 0;
144 Type *ValueType = nullptr;
146 public:
147 BasicExpression(unsigned NumOperands)
148 : BasicExpression(NumOperands, ET_Basic) {}
149 BasicExpression(unsigned NumOperands, ExpressionType ET)
150 : Expression(ET), MaxOperands(NumOperands) {}
151 BasicExpression() = delete;
152 BasicExpression(const BasicExpression &) = delete;
153 BasicExpression &operator=(const BasicExpression &) = delete;
154 ~BasicExpression() override;
156 static bool classof(const Expression *EB) {
157 ExpressionType ET = EB->getExpressionType();
158 return ET > ET_BasicStart && ET < ET_BasicEnd;
161 /// Swap two operands. Used during GVN to put commutative operands in
162 /// order.
163 void swapOperands(unsigned First, unsigned Second) {
164 std::swap(Operands[First], Operands[Second]);
167 Value *getOperand(unsigned N) const {
168 assert(Operands && "Operands not allocated");
169 assert(N < NumOperands && "Operand out of range");
170 return Operands[N];
173 void setOperand(unsigned N, Value *V) {
174 assert(Operands && "Operands not allocated before setting");
175 assert(N < NumOperands && "Operand out of range");
176 Operands[N] = V;
179 unsigned getNumOperands() const { return NumOperands; }
181 using op_iterator = Value **;
182 using const_op_iterator = Value *const *;
184 op_iterator op_begin() { return Operands; }
185 op_iterator op_end() { return Operands + NumOperands; }
186 const_op_iterator op_begin() const { return Operands; }
187 const_op_iterator op_end() const { return Operands + NumOperands; }
188 iterator_range<op_iterator> operands() {
189 return iterator_range<op_iterator>(op_begin(), op_end());
191 iterator_range<const_op_iterator> operands() const {
192 return iterator_range<const_op_iterator>(op_begin(), op_end());
195 void op_push_back(Value *Arg) {
196 assert(NumOperands < MaxOperands && "Tried to add too many operands");
197 assert(Operands && "Operandss not allocated before pushing");
198 Operands[NumOperands++] = Arg;
200 bool op_empty() const { return getNumOperands() == 0; }
202 void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
203 assert(!Operands && "Operands already allocated");
204 Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
206 void deallocateOperands(RecyclerType &Recycler) {
207 Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
210 void setType(Type *T) { ValueType = T; }
211 Type *getType() const { return ValueType; }
213 bool equals(const Expression &Other) const override {
214 if (getOpcode() != Other.getOpcode())
215 return false;
217 const auto &OE = cast<BasicExpression>(Other);
218 return getType() == OE.getType() && NumOperands == OE.NumOperands &&
219 std::equal(op_begin(), op_end(), OE.op_begin());
222 hash_code getHashValue() const override {
223 return hash_combine(this->Expression::getHashValue(), ValueType,
224 hash_combine_range(op_begin(), op_end()));
227 // Debugging support
228 void printInternal(raw_ostream &OS, bool PrintEType) const override {
229 if (PrintEType)
230 OS << "ExpressionTypeBasic, ";
232 this->Expression::printInternal(OS, false);
233 OS << "operands = {";
234 for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
235 OS << "[" << i << "] = ";
236 Operands[i]->printAsOperand(OS);
237 OS << " ";
239 OS << "} ";
243 class op_inserter
244 : public std::iterator<std::output_iterator_tag, void, void, void, void> {
245 private:
246 using Container = BasicExpression;
248 Container *BE;
250 public:
251 explicit op_inserter(BasicExpression &E) : BE(&E) {}
252 explicit op_inserter(BasicExpression *E) : BE(E) {}
254 op_inserter &operator=(Value *val) {
255 BE->op_push_back(val);
256 return *this;
258 op_inserter &operator*() { return *this; }
259 op_inserter &operator++() { return *this; }
260 op_inserter &operator++(int) { return *this; }
263 class MemoryExpression : public BasicExpression {
264 private:
265 const MemoryAccess *MemoryLeader;
267 public:
268 MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
269 const MemoryAccess *MemoryLeader)
270 : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
271 MemoryExpression() = delete;
272 MemoryExpression(const MemoryExpression &) = delete;
273 MemoryExpression &operator=(const MemoryExpression &) = delete;
275 static bool classof(const Expression *EB) {
276 return EB->getExpressionType() > ET_MemoryStart &&
277 EB->getExpressionType() < ET_MemoryEnd;
280 hash_code getHashValue() const override {
281 return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
284 bool equals(const Expression &Other) const override {
285 if (!this->BasicExpression::equals(Other))
286 return false;
287 const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
289 return MemoryLeader == OtherMCE.MemoryLeader;
292 const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
293 void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
296 class CallExpression final : public MemoryExpression {
297 private:
298 CallInst *Call;
300 public:
301 CallExpression(unsigned NumOperands, CallInst *C,
302 const MemoryAccess *MemoryLeader)
303 : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
304 CallExpression() = delete;
305 CallExpression(const CallExpression &) = delete;
306 CallExpression &operator=(const CallExpression &) = delete;
307 ~CallExpression() override;
309 static bool classof(const Expression *EB) {
310 return EB->getExpressionType() == ET_Call;
313 // Debugging support
314 void printInternal(raw_ostream &OS, bool PrintEType) const override {
315 if (PrintEType)
316 OS << "ExpressionTypeCall, ";
317 this->BasicExpression::printInternal(OS, false);
318 OS << " represents call at ";
319 Call->printAsOperand(OS);
323 class LoadExpression final : public MemoryExpression {
324 private:
325 LoadInst *Load;
326 unsigned Alignment;
328 public:
329 LoadExpression(unsigned NumOperands, LoadInst *L,
330 const MemoryAccess *MemoryLeader)
331 : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
333 LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
334 const MemoryAccess *MemoryLeader)
335 : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {
336 Alignment = L ? L->getAlignment() : 0;
339 LoadExpression() = delete;
340 LoadExpression(const LoadExpression &) = delete;
341 LoadExpression &operator=(const LoadExpression &) = delete;
342 ~LoadExpression() override;
344 static bool classof(const Expression *EB) {
345 return EB->getExpressionType() == ET_Load;
348 LoadInst *getLoadInst() const { return Load; }
349 void setLoadInst(LoadInst *L) { Load = L; }
351 unsigned getAlignment() const { return Alignment; }
352 void setAlignment(unsigned Align) { Alignment = Align; }
354 bool equals(const Expression &Other) const override;
355 bool exactlyEquals(const Expression &Other) const override {
356 return Expression::exactlyEquals(Other) &&
357 cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
360 // Debugging support
361 void printInternal(raw_ostream &OS, bool PrintEType) const override {
362 if (PrintEType)
363 OS << "ExpressionTypeLoad, ";
364 this->BasicExpression::printInternal(OS, false);
365 OS << " represents Load at ";
366 Load->printAsOperand(OS);
367 OS << " with MemoryLeader " << *getMemoryLeader();
371 class StoreExpression final : public MemoryExpression {
372 private:
373 StoreInst *Store;
374 Value *StoredValue;
376 public:
377 StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
378 const MemoryAccess *MemoryLeader)
379 : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
380 StoredValue(StoredValue) {}
381 StoreExpression() = delete;
382 StoreExpression(const StoreExpression &) = delete;
383 StoreExpression &operator=(const StoreExpression &) = delete;
384 ~StoreExpression() override;
386 static bool classof(const Expression *EB) {
387 return EB->getExpressionType() == ET_Store;
390 StoreInst *getStoreInst() const { return Store; }
391 Value *getStoredValue() const { return StoredValue; }
393 bool equals(const Expression &Other) const override;
395 bool exactlyEquals(const Expression &Other) const override {
396 return Expression::exactlyEquals(Other) &&
397 cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
400 // Debugging support
401 void printInternal(raw_ostream &OS, bool PrintEType) const override {
402 if (PrintEType)
403 OS << "ExpressionTypeStore, ";
404 this->BasicExpression::printInternal(OS, false);
405 OS << " represents Store " << *Store;
406 OS << " with StoredValue ";
407 StoredValue->printAsOperand(OS);
408 OS << " and MemoryLeader " << *getMemoryLeader();
412 class AggregateValueExpression final : public BasicExpression {
413 private:
414 unsigned MaxIntOperands;
415 unsigned NumIntOperands = 0;
416 unsigned *IntOperands = nullptr;
418 public:
419 AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
420 : BasicExpression(NumOperands, ET_AggregateValue),
421 MaxIntOperands(NumIntOperands) {}
422 AggregateValueExpression() = delete;
423 AggregateValueExpression(const AggregateValueExpression &) = delete;
424 AggregateValueExpression &
425 operator=(const AggregateValueExpression &) = delete;
426 ~AggregateValueExpression() override;
428 static bool classof(const Expression *EB) {
429 return EB->getExpressionType() == ET_AggregateValue;
432 using int_arg_iterator = unsigned *;
433 using const_int_arg_iterator = const unsigned *;
435 int_arg_iterator int_op_begin() { return IntOperands; }
436 int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
437 const_int_arg_iterator int_op_begin() const { return IntOperands; }
438 const_int_arg_iterator int_op_end() const {
439 return IntOperands + NumIntOperands;
441 unsigned int_op_size() const { return NumIntOperands; }
442 bool int_op_empty() const { return NumIntOperands == 0; }
443 void int_op_push_back(unsigned IntOperand) {
444 assert(NumIntOperands < MaxIntOperands &&
445 "Tried to add too many int operands");
446 assert(IntOperands && "Operands not allocated before pushing");
447 IntOperands[NumIntOperands++] = IntOperand;
450 virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
451 assert(!IntOperands && "Operands already allocated");
452 IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
455 bool equals(const Expression &Other) const override {
456 if (!this->BasicExpression::equals(Other))
457 return false;
458 const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
459 return NumIntOperands == OE.NumIntOperands &&
460 std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
463 hash_code getHashValue() const override {
464 return hash_combine(this->BasicExpression::getHashValue(),
465 hash_combine_range(int_op_begin(), int_op_end()));
468 // Debugging support
469 void printInternal(raw_ostream &OS, bool PrintEType) const override {
470 if (PrintEType)
471 OS << "ExpressionTypeAggregateValue, ";
472 this->BasicExpression::printInternal(OS, false);
473 OS << ", intoperands = {";
474 for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
475 OS << "[" << i << "] = " << IntOperands[i] << " ";
477 OS << "}";
481 class int_op_inserter
482 : public std::iterator<std::output_iterator_tag, void, void, void, void> {
483 private:
484 using Container = AggregateValueExpression;
486 Container *AVE;
488 public:
489 explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
490 explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
492 int_op_inserter &operator=(unsigned int val) {
493 AVE->int_op_push_back(val);
494 return *this;
496 int_op_inserter &operator*() { return *this; }
497 int_op_inserter &operator++() { return *this; }
498 int_op_inserter &operator++(int) { return *this; }
501 class PHIExpression final : public BasicExpression {
502 private:
503 BasicBlock *BB;
505 public:
506 PHIExpression(unsigned NumOperands, BasicBlock *B)
507 : BasicExpression(NumOperands, ET_Phi), BB(B) {}
508 PHIExpression() = delete;
509 PHIExpression(const PHIExpression &) = delete;
510 PHIExpression &operator=(const PHIExpression &) = delete;
511 ~PHIExpression() override;
513 static bool classof(const Expression *EB) {
514 return EB->getExpressionType() == ET_Phi;
517 bool equals(const Expression &Other) const override {
518 if (!this->BasicExpression::equals(Other))
519 return false;
520 const PHIExpression &OE = cast<PHIExpression>(Other);
521 return BB == OE.BB;
524 hash_code getHashValue() const override {
525 return hash_combine(this->BasicExpression::getHashValue(), BB);
528 // Debugging support
529 void printInternal(raw_ostream &OS, bool PrintEType) const override {
530 if (PrintEType)
531 OS << "ExpressionTypePhi, ";
532 this->BasicExpression::printInternal(OS, false);
533 OS << "bb = " << BB;
537 class DeadExpression final : public Expression {
538 public:
539 DeadExpression() : Expression(ET_Dead) {}
540 DeadExpression(const DeadExpression &) = delete;
541 DeadExpression &operator=(const DeadExpression &) = delete;
543 static bool classof(const Expression *E) {
544 return E->getExpressionType() == ET_Dead;
548 class VariableExpression final : public Expression {
549 private:
550 Value *VariableValue;
552 public:
553 VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
554 VariableExpression() = delete;
555 VariableExpression(const VariableExpression &) = delete;
556 VariableExpression &operator=(const VariableExpression &) = delete;
558 static bool classof(const Expression *EB) {
559 return EB->getExpressionType() == ET_Variable;
562 Value *getVariableValue() const { return VariableValue; }
563 void setVariableValue(Value *V) { VariableValue = V; }
565 bool equals(const Expression &Other) const override {
566 const VariableExpression &OC = cast<VariableExpression>(Other);
567 return VariableValue == OC.VariableValue;
570 hash_code getHashValue() const override {
571 return hash_combine(this->Expression::getHashValue(),
572 VariableValue->getType(), VariableValue);
575 // Debugging support
576 void printInternal(raw_ostream &OS, bool PrintEType) const override {
577 if (PrintEType)
578 OS << "ExpressionTypeVariable, ";
579 this->Expression::printInternal(OS, false);
580 OS << " variable = " << *VariableValue;
584 class ConstantExpression final : public Expression {
585 private:
586 Constant *ConstantValue = nullptr;
588 public:
589 ConstantExpression() : Expression(ET_Constant) {}
590 ConstantExpression(Constant *constantValue)
591 : Expression(ET_Constant), ConstantValue(constantValue) {}
592 ConstantExpression(const ConstantExpression &) = delete;
593 ConstantExpression &operator=(const ConstantExpression &) = delete;
595 static bool classof(const Expression *EB) {
596 return EB->getExpressionType() == ET_Constant;
599 Constant *getConstantValue() const { return ConstantValue; }
600 void setConstantValue(Constant *V) { ConstantValue = V; }
602 bool equals(const Expression &Other) const override {
603 const ConstantExpression &OC = cast<ConstantExpression>(Other);
604 return ConstantValue == OC.ConstantValue;
607 hash_code getHashValue() const override {
608 return hash_combine(this->Expression::getHashValue(),
609 ConstantValue->getType(), ConstantValue);
612 // Debugging support
613 void printInternal(raw_ostream &OS, bool PrintEType) const override {
614 if (PrintEType)
615 OS << "ExpressionTypeConstant, ";
616 this->Expression::printInternal(OS, false);
617 OS << " constant = " << *ConstantValue;
621 class UnknownExpression final : public Expression {
622 private:
623 Instruction *Inst;
625 public:
626 UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
627 UnknownExpression() = delete;
628 UnknownExpression(const UnknownExpression &) = delete;
629 UnknownExpression &operator=(const UnknownExpression &) = delete;
631 static bool classof(const Expression *EB) {
632 return EB->getExpressionType() == ET_Unknown;
635 Instruction *getInstruction() const { return Inst; }
636 void setInstruction(Instruction *I) { Inst = I; }
638 bool equals(const Expression &Other) const override {
639 const auto &OU = cast<UnknownExpression>(Other);
640 return Inst == OU.Inst;
643 hash_code getHashValue() const override {
644 return hash_combine(this->Expression::getHashValue(), Inst);
647 // Debugging support
648 void printInternal(raw_ostream &OS, bool PrintEType) const override {
649 if (PrintEType)
650 OS << "ExpressionTypeUnknown, ";
651 this->Expression::printInternal(OS, false);
652 OS << " inst = " << *Inst;
656 } // end namespace GVNExpression
658 } // end namespace llvm
660 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H