[InstCombine] Signed saturation patterns
[llvm-core.git] / include / llvm / Transforms / Scalar / GVNExpression.h
blob1600d1af32426bbab2bd081df0cda554f29970d5
1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 ///
11 /// The header file for the GVN pass that contains expression handling
12 /// classes
14 //===----------------------------------------------------------------------===//
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <algorithm>
31 #include <cassert>
32 #include <iterator>
33 #include <utility>
35 namespace llvm {
37 class BasicBlock;
38 class Type;
40 namespace GVNExpression {
/// Discriminator for the Expression class hierarchy (used by the classof()
/// helpers for LLVM-style isa/cast/dyn_cast).
///
/// The *Start / *End enumerators are range markers, not real kinds. A kind K
/// belongs to the BasicExpression family when ET_BasicStart < K < ET_BasicEnd
/// and to the MemoryExpression family when ET_MemoryStart < K < ET_MemoryEnd,
/// so new kinds must be inserted inside the appropriate range.
enum ExpressionType {
  ET_Base,
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,

  // --- BasicExpression kinds (exclusive range) ---
  ET_BasicStart,
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,

  // --- MemoryExpression kinds, nested inside the basic range ---
  ET_MemoryStart,
  ET_Call,
  ET_Load,
  ET_Store,
  ET_MemoryEnd,

  ET_BasicEnd,
};
60 class Expression {
61 private:
62 ExpressionType EType;
63 unsigned Opcode;
64 mutable hash_code HashVal = 0;
66 public:
67 Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
68 : EType(ET), Opcode(O) {}
69 Expression(const Expression &) = delete;
70 Expression &operator=(const Expression &) = delete;
71 virtual ~Expression();
73 static unsigned getEmptyKey() { return ~0U; }
74 static unsigned getTombstoneKey() { return ~1U; }
76 bool operator!=(const Expression &Other) const { return !(*this == Other); }
77 bool operator==(const Expression &Other) const {
78 if (getOpcode() != Other.getOpcode())
79 return false;
80 if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
81 return true;
82 // Compare the expression type for anything but load and store.
83 // For load and store we set the opcode to zero to make them equal.
84 if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
85 getExpressionType() != Other.getExpressionType())
86 return false;
88 return equals(Other);
91 hash_code getComputedHash() const {
92 // It's theoretically possible for a thing to hash to zero. In that case,
93 // we will just compute the hash a few extra times, which is no worse that
94 // we did before, which was to compute it always.
95 if (static_cast<unsigned>(HashVal) == 0)
96 HashVal = getHashValue();
97 return HashVal;
100 virtual bool equals(const Expression &Other) const { return true; }
102 // Return true if the two expressions are exactly the same, including the
103 // normally ignored fields.
104 virtual bool exactlyEquals(const Expression &Other) const {
105 return getExpressionType() == Other.getExpressionType() && equals(Other);
108 unsigned getOpcode() const { return Opcode; }
109 void setOpcode(unsigned opcode) { Opcode = opcode; }
110 ExpressionType getExpressionType() const { return EType; }
112 // We deliberately leave the expression type out of the hash value.
113 virtual hash_code getHashValue() const { return getOpcode(); }
115 // Debugging support
116 virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
117 if (PrintEType)
118 OS << "etype = " << getExpressionType() << ",";
119 OS << "opcode = " << getOpcode() << ", ";
122 void print(raw_ostream &OS) const {
123 OS << "{ ";
124 printInternal(OS, true);
125 OS << "}";
128 LLVM_DUMP_METHOD void dump() const;
131 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
132 E.print(OS);
133 return OS;
136 class BasicExpression : public Expression {
137 private:
138 using RecyclerType = ArrayRecycler<Value *>;
139 using RecyclerCapacity = RecyclerType::Capacity;
141 Value **Operands = nullptr;
142 unsigned MaxOperands;
143 unsigned NumOperands = 0;
144 Type *ValueType = nullptr;
146 public:
147 BasicExpression(unsigned NumOperands)
148 : BasicExpression(NumOperands, ET_Basic) {}
149 BasicExpression(unsigned NumOperands, ExpressionType ET)
150 : Expression(ET), MaxOperands(NumOperands) {}
151 BasicExpression() = delete;
152 BasicExpression(const BasicExpression &) = delete;
153 BasicExpression &operator=(const BasicExpression &) = delete;
154 ~BasicExpression() override;
156 static bool classof(const Expression *EB) {
157 ExpressionType ET = EB->getExpressionType();
158 return ET > ET_BasicStart && ET < ET_BasicEnd;
161 /// Swap two operands. Used during GVN to put commutative operands in
162 /// order.
163 void swapOperands(unsigned First, unsigned Second) {
164 std::swap(Operands[First], Operands[Second]);
167 Value *getOperand(unsigned N) const {
168 assert(Operands && "Operands not allocated");
169 assert(N < NumOperands && "Operand out of range");
170 return Operands[N];
173 void setOperand(unsigned N, Value *V) {
174 assert(Operands && "Operands not allocated before setting");
175 assert(N < NumOperands && "Operand out of range");
176 Operands[N] = V;
179 unsigned getNumOperands() const { return NumOperands; }
181 using op_iterator = Value **;
182 using const_op_iterator = Value *const *;
184 op_iterator op_begin() { return Operands; }
185 op_iterator op_end() { return Operands + NumOperands; }
186 const_op_iterator op_begin() const { return Operands; }
187 const_op_iterator op_end() const { return Operands + NumOperands; }
188 iterator_range<op_iterator> operands() {
189 return iterator_range<op_iterator>(op_begin(), op_end());
191 iterator_range<const_op_iterator> operands() const {
192 return iterator_range<const_op_iterator>(op_begin(), op_end());
195 void op_push_back(Value *Arg) {
196 assert(NumOperands < MaxOperands && "Tried to add too many operands");
197 assert(Operands && "Operandss not allocated before pushing");
198 Operands[NumOperands++] = Arg;
200 bool op_empty() const { return getNumOperands() == 0; }
202 void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
203 assert(!Operands && "Operands already allocated");
204 Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
206 void deallocateOperands(RecyclerType &Recycler) {
207 Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
210 void setType(Type *T) { ValueType = T; }
211 Type *getType() const { return ValueType; }
213 bool equals(const Expression &Other) const override {
214 if (getOpcode() != Other.getOpcode())
215 return false;
217 const auto &OE = cast<BasicExpression>(Other);
218 return getType() == OE.getType() && NumOperands == OE.NumOperands &&
219 std::equal(op_begin(), op_end(), OE.op_begin());
222 hash_code getHashValue() const override {
223 return hash_combine(this->Expression::getHashValue(), ValueType,
224 hash_combine_range(op_begin(), op_end()));
227 // Debugging support
228 void printInternal(raw_ostream &OS, bool PrintEType) const override {
229 if (PrintEType)
230 OS << "ExpressionTypeBasic, ";
232 this->Expression::printInternal(OS, false);
233 OS << "operands = {";
234 for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
235 OS << "[" << i << "] = ";
236 Operands[i]->printAsOperand(OS);
237 OS << " ";
239 OS << "} ";
243 class op_inserter
244 : public std::iterator<std::output_iterator_tag, void, void, void, void> {
245 private:
246 using Container = BasicExpression;
248 Container *BE;
250 public:
251 explicit op_inserter(BasicExpression &E) : BE(&E) {}
252 explicit op_inserter(BasicExpression *E) : BE(E) {}
254 op_inserter &operator=(Value *val) {
255 BE->op_push_back(val);
256 return *this;
258 op_inserter &operator*() { return *this; }
259 op_inserter &operator++() { return *this; }
260 op_inserter &operator++(int) { return *this; }
263 class MemoryExpression : public BasicExpression {
264 private:
265 const MemoryAccess *MemoryLeader;
267 public:
268 MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
269 const MemoryAccess *MemoryLeader)
270 : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
271 MemoryExpression() = delete;
272 MemoryExpression(const MemoryExpression &) = delete;
273 MemoryExpression &operator=(const MemoryExpression &) = delete;
275 static bool classof(const Expression *EB) {
276 return EB->getExpressionType() > ET_MemoryStart &&
277 EB->getExpressionType() < ET_MemoryEnd;
280 hash_code getHashValue() const override {
281 return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
284 bool equals(const Expression &Other) const override {
285 if (!this->BasicExpression::equals(Other))
286 return false;
287 const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
289 return MemoryLeader == OtherMCE.MemoryLeader;
292 const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
293 void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
296 class CallExpression final : public MemoryExpression {
297 private:
298 CallInst *Call;
300 public:
301 CallExpression(unsigned NumOperands, CallInst *C,
302 const MemoryAccess *MemoryLeader)
303 : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
304 CallExpression() = delete;
305 CallExpression(const CallExpression &) = delete;
306 CallExpression &operator=(const CallExpression &) = delete;
307 ~CallExpression() override;
309 static bool classof(const Expression *EB) {
310 return EB->getExpressionType() == ET_Call;
313 // Debugging support
314 void printInternal(raw_ostream &OS, bool PrintEType) const override {
315 if (PrintEType)
316 OS << "ExpressionTypeCall, ";
317 this->BasicExpression::printInternal(OS, false);
318 OS << " represents call at ";
319 Call->printAsOperand(OS);
323 class LoadExpression final : public MemoryExpression {
324 private:
325 LoadInst *Load;
326 MaybeAlign Alignment;
328 public:
329 LoadExpression(unsigned NumOperands, LoadInst *L,
330 const MemoryAccess *MemoryLeader)
331 : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
333 LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
334 const MemoryAccess *MemoryLeader)
335 : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {
336 if (L)
337 Alignment = MaybeAlign(L->getAlignment());
340 LoadExpression() = delete;
341 LoadExpression(const LoadExpression &) = delete;
342 LoadExpression &operator=(const LoadExpression &) = delete;
343 ~LoadExpression() override;
345 static bool classof(const Expression *EB) {
346 return EB->getExpressionType() == ET_Load;
349 LoadInst *getLoadInst() const { return Load; }
350 void setLoadInst(LoadInst *L) { Load = L; }
352 MaybeAlign getAlignment() const { return Alignment; }
353 void setAlignment(MaybeAlign Align) { Alignment = Align; }
355 bool equals(const Expression &Other) const override;
356 bool exactlyEquals(const Expression &Other) const override {
357 return Expression::exactlyEquals(Other) &&
358 cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
361 // Debugging support
362 void printInternal(raw_ostream &OS, bool PrintEType) const override {
363 if (PrintEType)
364 OS << "ExpressionTypeLoad, ";
365 this->BasicExpression::printInternal(OS, false);
366 OS << " represents Load at ";
367 Load->printAsOperand(OS);
368 OS << " with MemoryLeader " << *getMemoryLeader();
372 class StoreExpression final : public MemoryExpression {
373 private:
374 StoreInst *Store;
375 Value *StoredValue;
377 public:
378 StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
379 const MemoryAccess *MemoryLeader)
380 : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
381 StoredValue(StoredValue) {}
382 StoreExpression() = delete;
383 StoreExpression(const StoreExpression &) = delete;
384 StoreExpression &operator=(const StoreExpression &) = delete;
385 ~StoreExpression() override;
387 static bool classof(const Expression *EB) {
388 return EB->getExpressionType() == ET_Store;
391 StoreInst *getStoreInst() const { return Store; }
392 Value *getStoredValue() const { return StoredValue; }
394 bool equals(const Expression &Other) const override;
396 bool exactlyEquals(const Expression &Other) const override {
397 return Expression::exactlyEquals(Other) &&
398 cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
401 // Debugging support
402 void printInternal(raw_ostream &OS, bool PrintEType) const override {
403 if (PrintEType)
404 OS << "ExpressionTypeStore, ";
405 this->BasicExpression::printInternal(OS, false);
406 OS << " represents Store " << *Store;
407 OS << " with StoredValue ";
408 StoredValue->printAsOperand(OS);
409 OS << " and MemoryLeader " << *getMemoryLeader();
413 class AggregateValueExpression final : public BasicExpression {
414 private:
415 unsigned MaxIntOperands;
416 unsigned NumIntOperands = 0;
417 unsigned *IntOperands = nullptr;
419 public:
420 AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
421 : BasicExpression(NumOperands, ET_AggregateValue),
422 MaxIntOperands(NumIntOperands) {}
423 AggregateValueExpression() = delete;
424 AggregateValueExpression(const AggregateValueExpression &) = delete;
425 AggregateValueExpression &
426 operator=(const AggregateValueExpression &) = delete;
427 ~AggregateValueExpression() override;
429 static bool classof(const Expression *EB) {
430 return EB->getExpressionType() == ET_AggregateValue;
433 using int_arg_iterator = unsigned *;
434 using const_int_arg_iterator = const unsigned *;
436 int_arg_iterator int_op_begin() { return IntOperands; }
437 int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
438 const_int_arg_iterator int_op_begin() const { return IntOperands; }
439 const_int_arg_iterator int_op_end() const {
440 return IntOperands + NumIntOperands;
442 unsigned int_op_size() const { return NumIntOperands; }
443 bool int_op_empty() const { return NumIntOperands == 0; }
444 void int_op_push_back(unsigned IntOperand) {
445 assert(NumIntOperands < MaxIntOperands &&
446 "Tried to add too many int operands");
447 assert(IntOperands && "Operands not allocated before pushing");
448 IntOperands[NumIntOperands++] = IntOperand;
451 virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
452 assert(!IntOperands && "Operands already allocated");
453 IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
456 bool equals(const Expression &Other) const override {
457 if (!this->BasicExpression::equals(Other))
458 return false;
459 const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
460 return NumIntOperands == OE.NumIntOperands &&
461 std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
464 hash_code getHashValue() const override {
465 return hash_combine(this->BasicExpression::getHashValue(),
466 hash_combine_range(int_op_begin(), int_op_end()));
469 // Debugging support
470 void printInternal(raw_ostream &OS, bool PrintEType) const override {
471 if (PrintEType)
472 OS << "ExpressionTypeAggregateValue, ";
473 this->BasicExpression::printInternal(OS, false);
474 OS << ", intoperands = {";
475 for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
476 OS << "[" << i << "] = " << IntOperands[i] << " ";
478 OS << "}";
482 class int_op_inserter
483 : public std::iterator<std::output_iterator_tag, void, void, void, void> {
484 private:
485 using Container = AggregateValueExpression;
487 Container *AVE;
489 public:
490 explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
491 explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
493 int_op_inserter &operator=(unsigned int val) {
494 AVE->int_op_push_back(val);
495 return *this;
497 int_op_inserter &operator*() { return *this; }
498 int_op_inserter &operator++() { return *this; }
499 int_op_inserter &operator++(int) { return *this; }
502 class PHIExpression final : public BasicExpression {
503 private:
504 BasicBlock *BB;
506 public:
507 PHIExpression(unsigned NumOperands, BasicBlock *B)
508 : BasicExpression(NumOperands, ET_Phi), BB(B) {}
509 PHIExpression() = delete;
510 PHIExpression(const PHIExpression &) = delete;
511 PHIExpression &operator=(const PHIExpression &) = delete;
512 ~PHIExpression() override;
514 static bool classof(const Expression *EB) {
515 return EB->getExpressionType() == ET_Phi;
518 bool equals(const Expression &Other) const override {
519 if (!this->BasicExpression::equals(Other))
520 return false;
521 const PHIExpression &OE = cast<PHIExpression>(Other);
522 return BB == OE.BB;
525 hash_code getHashValue() const override {
526 return hash_combine(this->BasicExpression::getHashValue(), BB);
529 // Debugging support
530 void printInternal(raw_ostream &OS, bool PrintEType) const override {
531 if (PrintEType)
532 OS << "ExpressionTypePhi, ";
533 this->BasicExpression::printInternal(OS, false);
534 OS << "bb = " << BB;
538 class DeadExpression final : public Expression {
539 public:
540 DeadExpression() : Expression(ET_Dead) {}
541 DeadExpression(const DeadExpression &) = delete;
542 DeadExpression &operator=(const DeadExpression &) = delete;
544 static bool classof(const Expression *E) {
545 return E->getExpressionType() == ET_Dead;
549 class VariableExpression final : public Expression {
550 private:
551 Value *VariableValue;
553 public:
554 VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
555 VariableExpression() = delete;
556 VariableExpression(const VariableExpression &) = delete;
557 VariableExpression &operator=(const VariableExpression &) = delete;
559 static bool classof(const Expression *EB) {
560 return EB->getExpressionType() == ET_Variable;
563 Value *getVariableValue() const { return VariableValue; }
564 void setVariableValue(Value *V) { VariableValue = V; }
566 bool equals(const Expression &Other) const override {
567 const VariableExpression &OC = cast<VariableExpression>(Other);
568 return VariableValue == OC.VariableValue;
571 hash_code getHashValue() const override {
572 return hash_combine(this->Expression::getHashValue(),
573 VariableValue->getType(), VariableValue);
576 // Debugging support
577 void printInternal(raw_ostream &OS, bool PrintEType) const override {
578 if (PrintEType)
579 OS << "ExpressionTypeVariable, ";
580 this->Expression::printInternal(OS, false);
581 OS << " variable = " << *VariableValue;
585 class ConstantExpression final : public Expression {
586 private:
587 Constant *ConstantValue = nullptr;
589 public:
590 ConstantExpression() : Expression(ET_Constant) {}
591 ConstantExpression(Constant *constantValue)
592 : Expression(ET_Constant), ConstantValue(constantValue) {}
593 ConstantExpression(const ConstantExpression &) = delete;
594 ConstantExpression &operator=(const ConstantExpression &) = delete;
596 static bool classof(const Expression *EB) {
597 return EB->getExpressionType() == ET_Constant;
600 Constant *getConstantValue() const { return ConstantValue; }
601 void setConstantValue(Constant *V) { ConstantValue = V; }
603 bool equals(const Expression &Other) const override {
604 const ConstantExpression &OC = cast<ConstantExpression>(Other);
605 return ConstantValue == OC.ConstantValue;
608 hash_code getHashValue() const override {
609 return hash_combine(this->Expression::getHashValue(),
610 ConstantValue->getType(), ConstantValue);
613 // Debugging support
614 void printInternal(raw_ostream &OS, bool PrintEType) const override {
615 if (PrintEType)
616 OS << "ExpressionTypeConstant, ";
617 this->Expression::printInternal(OS, false);
618 OS << " constant = " << *ConstantValue;
622 class UnknownExpression final : public Expression {
623 private:
624 Instruction *Inst;
626 public:
627 UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
628 UnknownExpression() = delete;
629 UnknownExpression(const UnknownExpression &) = delete;
630 UnknownExpression &operator=(const UnknownExpression &) = delete;
632 static bool classof(const Expression *EB) {
633 return EB->getExpressionType() == ET_Unknown;
636 Instruction *getInstruction() const { return Inst; }
637 void setInstruction(Instruction *I) { Inst = I; }
639 bool equals(const Expression &Other) const override {
640 const auto &OU = cast<UnknownExpression>(Other);
641 return Inst == OU.Inst;
644 hash_code getHashValue() const override {
645 return hash_combine(this->Expression::getHashValue(), Inst);
648 // Debugging support
649 void printInternal(raw_ostream &OS, bool PrintEType) const override {
650 if (PrintEType)
651 OS << "ExpressionTypeUnknown, ";
652 this->Expression::printInternal(OS, false);
653 OS << " inst = " << *Inst;
657 } // end namespace GVNExpression
659 } // end namespace llvm
661 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H