1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
11 /// The header file for the GVN pass that contains expression handling
14 //===----------------------------------------------------------------------===//
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
40 namespace GVNExpression
{
64 mutable hash_code HashVal
= 0;
67 Expression(ExpressionType ET
= ET_Base
, unsigned O
= ~2U)
68 : EType(ET
), Opcode(O
) {}
69 Expression(const Expression
&) = delete;
70 Expression
&operator=(const Expression
&) = delete;
71 virtual ~Expression();
73 static unsigned getEmptyKey() { return ~0U; }
74 static unsigned getTombstoneKey() { return ~1U; }
76 bool operator!=(const Expression
&Other
) const { return !(*this == Other
); }
77 bool operator==(const Expression
&Other
) const {
78 if (getOpcode() != Other
.getOpcode())
80 if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
82 // Compare the expression type for anything but load and store.
83 // For load and store we set the opcode to zero to make them equal.
84 if (getExpressionType() != ET_Load
&& getExpressionType() != ET_Store
&&
85 getExpressionType() != Other
.getExpressionType())
91 hash_code
getComputedHash() const {
92 // It's theoretically possible for a thing to hash to zero. In that case,
93 // we will just compute the hash a few extra times, which is no worse that
94 // we did before, which was to compute it always.
95 if (static_cast<unsigned>(HashVal
) == 0)
96 HashVal
= getHashValue();
100 virtual bool equals(const Expression
&Other
) const { return true; }
102 // Return true if the two expressions are exactly the same, including the
103 // normally ignored fields.
104 virtual bool exactlyEquals(const Expression
&Other
) const {
105 return getExpressionType() == Other
.getExpressionType() && equals(Other
);
108 unsigned getOpcode() const { return Opcode
; }
109 void setOpcode(unsigned opcode
) { Opcode
= opcode
; }
110 ExpressionType
getExpressionType() const { return EType
; }
112 // We deliberately leave the expression type out of the hash value.
113 virtual hash_code
getHashValue() const { return getOpcode(); }
116 virtual void printInternal(raw_ostream
&OS
, bool PrintEType
) const {
118 OS
<< "etype = " << getExpressionType() << ",";
119 OS
<< "opcode = " << getOpcode() << ", ";
122 void print(raw_ostream
&OS
) const {
124 printInternal(OS
, true);
128 LLVM_DUMP_METHOD
void dump() const;
131 inline raw_ostream
&operator<<(raw_ostream
&OS
, const Expression
&E
) {
136 class BasicExpression
: public Expression
{
138 using RecyclerType
= ArrayRecycler
<Value
*>;
139 using RecyclerCapacity
= RecyclerType::Capacity
;
141 Value
**Operands
= nullptr;
142 unsigned MaxOperands
;
143 unsigned NumOperands
= 0;
144 Type
*ValueType
= nullptr;
147 BasicExpression(unsigned NumOperands
)
148 : BasicExpression(NumOperands
, ET_Basic
) {}
149 BasicExpression(unsigned NumOperands
, ExpressionType ET
)
150 : Expression(ET
), MaxOperands(NumOperands
) {}
151 BasicExpression() = delete;
152 BasicExpression(const BasicExpression
&) = delete;
153 BasicExpression
&operator=(const BasicExpression
&) = delete;
154 ~BasicExpression() override
;
156 static bool classof(const Expression
*EB
) {
157 ExpressionType ET
= EB
->getExpressionType();
158 return ET
> ET_BasicStart
&& ET
< ET_BasicEnd
;
161 /// Swap two operands. Used during GVN to put commutative operands in
163 void swapOperands(unsigned First
, unsigned Second
) {
164 std::swap(Operands
[First
], Operands
[Second
]);
167 Value
*getOperand(unsigned N
) const {
168 assert(Operands
&& "Operands not allocated");
169 assert(N
< NumOperands
&& "Operand out of range");
173 void setOperand(unsigned N
, Value
*V
) {
174 assert(Operands
&& "Operands not allocated before setting");
175 assert(N
< NumOperands
&& "Operand out of range");
179 unsigned getNumOperands() const { return NumOperands
; }
181 using op_iterator
= Value
**;
182 using const_op_iterator
= Value
*const *;
184 op_iterator
op_begin() { return Operands
; }
185 op_iterator
op_end() { return Operands
+ NumOperands
; }
186 const_op_iterator
op_begin() const { return Operands
; }
187 const_op_iterator
op_end() const { return Operands
+ NumOperands
; }
188 iterator_range
<op_iterator
> operands() {
189 return iterator_range
<op_iterator
>(op_begin(), op_end());
191 iterator_range
<const_op_iterator
> operands() const {
192 return iterator_range
<const_op_iterator
>(op_begin(), op_end());
195 void op_push_back(Value
*Arg
) {
196 assert(NumOperands
< MaxOperands
&& "Tried to add too many operands");
197 assert(Operands
&& "Operandss not allocated before pushing");
198 Operands
[NumOperands
++] = Arg
;
200 bool op_empty() const { return getNumOperands() == 0; }
202 void allocateOperands(RecyclerType
&Recycler
, BumpPtrAllocator
&Allocator
) {
203 assert(!Operands
&& "Operands already allocated");
204 Operands
= Recycler
.allocate(RecyclerCapacity::get(MaxOperands
), Allocator
);
206 void deallocateOperands(RecyclerType
&Recycler
) {
207 Recycler
.deallocate(RecyclerCapacity::get(MaxOperands
), Operands
);
210 void setType(Type
*T
) { ValueType
= T
; }
211 Type
*getType() const { return ValueType
; }
213 bool equals(const Expression
&Other
) const override
{
214 if (getOpcode() != Other
.getOpcode())
217 const auto &OE
= cast
<BasicExpression
>(Other
);
218 return getType() == OE
.getType() && NumOperands
== OE
.NumOperands
&&
219 std::equal(op_begin(), op_end(), OE
.op_begin());
222 hash_code
getHashValue() const override
{
223 return hash_combine(this->Expression::getHashValue(), ValueType
,
224 hash_combine_range(op_begin(), op_end()));
228 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
230 OS
<< "ExpressionTypeBasic, ";
232 this->Expression::printInternal(OS
, false);
233 OS
<< "operands = {";
234 for (unsigned i
= 0, e
= getNumOperands(); i
!= e
; ++i
) {
235 OS
<< "[" << i
<< "] = ";
236 Operands
[i
]->printAsOperand(OS
);
244 : public std::iterator
<std::output_iterator_tag
, void, void, void, void> {
246 using Container
= BasicExpression
;
251 explicit op_inserter(BasicExpression
&E
) : BE(&E
) {}
252 explicit op_inserter(BasicExpression
*E
) : BE(E
) {}
254 op_inserter
&operator=(Value
*val
) {
255 BE
->op_push_back(val
);
258 op_inserter
&operator*() { return *this; }
259 op_inserter
&operator++() { return *this; }
260 op_inserter
&operator++(int) { return *this; }
263 class MemoryExpression
: public BasicExpression
{
265 const MemoryAccess
*MemoryLeader
;
268 MemoryExpression(unsigned NumOperands
, enum ExpressionType EType
,
269 const MemoryAccess
*MemoryLeader
)
270 : BasicExpression(NumOperands
, EType
), MemoryLeader(MemoryLeader
) {}
271 MemoryExpression() = delete;
272 MemoryExpression(const MemoryExpression
&) = delete;
273 MemoryExpression
&operator=(const MemoryExpression
&) = delete;
275 static bool classof(const Expression
*EB
) {
276 return EB
->getExpressionType() > ET_MemoryStart
&&
277 EB
->getExpressionType() < ET_MemoryEnd
;
280 hash_code
getHashValue() const override
{
281 return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader
);
284 bool equals(const Expression
&Other
) const override
{
285 if (!this->BasicExpression::equals(Other
))
287 const MemoryExpression
&OtherMCE
= cast
<MemoryExpression
>(Other
);
289 return MemoryLeader
== OtherMCE
.MemoryLeader
;
292 const MemoryAccess
*getMemoryLeader() const { return MemoryLeader
; }
293 void setMemoryLeader(const MemoryAccess
*ML
) { MemoryLeader
= ML
; }
296 class CallExpression final
: public MemoryExpression
{
301 CallExpression(unsigned NumOperands
, CallInst
*C
,
302 const MemoryAccess
*MemoryLeader
)
303 : MemoryExpression(NumOperands
, ET_Call
, MemoryLeader
), Call(C
) {}
304 CallExpression() = delete;
305 CallExpression(const CallExpression
&) = delete;
306 CallExpression
&operator=(const CallExpression
&) = delete;
307 ~CallExpression() override
;
309 static bool classof(const Expression
*EB
) {
310 return EB
->getExpressionType() == ET_Call
;
314 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
316 OS
<< "ExpressionTypeCall, ";
317 this->BasicExpression::printInternal(OS
, false);
318 OS
<< " represents call at ";
319 Call
->printAsOperand(OS
);
323 class LoadExpression final
: public MemoryExpression
{
329 LoadExpression(unsigned NumOperands
, LoadInst
*L
,
330 const MemoryAccess
*MemoryLeader
)
331 : LoadExpression(ET_Load
, NumOperands
, L
, MemoryLeader
) {}
333 LoadExpression(enum ExpressionType EType
, unsigned NumOperands
, LoadInst
*L
,
334 const MemoryAccess
*MemoryLeader
)
335 : MemoryExpression(NumOperands
, EType
, MemoryLeader
), Load(L
) {
336 Alignment
= L
? L
->getAlignment() : 0;
339 LoadExpression() = delete;
340 LoadExpression(const LoadExpression
&) = delete;
341 LoadExpression
&operator=(const LoadExpression
&) = delete;
342 ~LoadExpression() override
;
344 static bool classof(const Expression
*EB
) {
345 return EB
->getExpressionType() == ET_Load
;
348 LoadInst
*getLoadInst() const { return Load
; }
349 void setLoadInst(LoadInst
*L
) { Load
= L
; }
351 unsigned getAlignment() const { return Alignment
; }
352 void setAlignment(unsigned Align
) { Alignment
= Align
; }
354 bool equals(const Expression
&Other
) const override
;
355 bool exactlyEquals(const Expression
&Other
) const override
{
356 return Expression::exactlyEquals(Other
) &&
357 cast
<LoadExpression
>(Other
).getLoadInst() == getLoadInst();
361 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
363 OS
<< "ExpressionTypeLoad, ";
364 this->BasicExpression::printInternal(OS
, false);
365 OS
<< " represents Load at ";
366 Load
->printAsOperand(OS
);
367 OS
<< " with MemoryLeader " << *getMemoryLeader();
371 class StoreExpression final
: public MemoryExpression
{
377 StoreExpression(unsigned NumOperands
, StoreInst
*S
, Value
*StoredValue
,
378 const MemoryAccess
*MemoryLeader
)
379 : MemoryExpression(NumOperands
, ET_Store
, MemoryLeader
), Store(S
),
380 StoredValue(StoredValue
) {}
381 StoreExpression() = delete;
382 StoreExpression(const StoreExpression
&) = delete;
383 StoreExpression
&operator=(const StoreExpression
&) = delete;
384 ~StoreExpression() override
;
386 static bool classof(const Expression
*EB
) {
387 return EB
->getExpressionType() == ET_Store
;
390 StoreInst
*getStoreInst() const { return Store
; }
391 Value
*getStoredValue() const { return StoredValue
; }
393 bool equals(const Expression
&Other
) const override
;
395 bool exactlyEquals(const Expression
&Other
) const override
{
396 return Expression::exactlyEquals(Other
) &&
397 cast
<StoreExpression
>(Other
).getStoreInst() == getStoreInst();
401 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
403 OS
<< "ExpressionTypeStore, ";
404 this->BasicExpression::printInternal(OS
, false);
405 OS
<< " represents Store " << *Store
;
406 OS
<< " with StoredValue ";
407 StoredValue
->printAsOperand(OS
);
408 OS
<< " and MemoryLeader " << *getMemoryLeader();
412 class AggregateValueExpression final
: public BasicExpression
{
414 unsigned MaxIntOperands
;
415 unsigned NumIntOperands
= 0;
416 unsigned *IntOperands
= nullptr;
419 AggregateValueExpression(unsigned NumOperands
, unsigned NumIntOperands
)
420 : BasicExpression(NumOperands
, ET_AggregateValue
),
421 MaxIntOperands(NumIntOperands
) {}
422 AggregateValueExpression() = delete;
423 AggregateValueExpression(const AggregateValueExpression
&) = delete;
424 AggregateValueExpression
&
425 operator=(const AggregateValueExpression
&) = delete;
426 ~AggregateValueExpression() override
;
428 static bool classof(const Expression
*EB
) {
429 return EB
->getExpressionType() == ET_AggregateValue
;
432 using int_arg_iterator
= unsigned *;
433 using const_int_arg_iterator
= const unsigned *;
435 int_arg_iterator
int_op_begin() { return IntOperands
; }
436 int_arg_iterator
int_op_end() { return IntOperands
+ NumIntOperands
; }
437 const_int_arg_iterator
int_op_begin() const { return IntOperands
; }
438 const_int_arg_iterator
int_op_end() const {
439 return IntOperands
+ NumIntOperands
;
441 unsigned int_op_size() const { return NumIntOperands
; }
442 bool int_op_empty() const { return NumIntOperands
== 0; }
443 void int_op_push_back(unsigned IntOperand
) {
444 assert(NumIntOperands
< MaxIntOperands
&&
445 "Tried to add too many int operands");
446 assert(IntOperands
&& "Operands not allocated before pushing");
447 IntOperands
[NumIntOperands
++] = IntOperand
;
450 virtual void allocateIntOperands(BumpPtrAllocator
&Allocator
) {
451 assert(!IntOperands
&& "Operands already allocated");
452 IntOperands
= Allocator
.Allocate
<unsigned>(MaxIntOperands
);
455 bool equals(const Expression
&Other
) const override
{
456 if (!this->BasicExpression::equals(Other
))
458 const AggregateValueExpression
&OE
= cast
<AggregateValueExpression
>(Other
);
459 return NumIntOperands
== OE
.NumIntOperands
&&
460 std::equal(int_op_begin(), int_op_end(), OE
.int_op_begin());
463 hash_code
getHashValue() const override
{
464 return hash_combine(this->BasicExpression::getHashValue(),
465 hash_combine_range(int_op_begin(), int_op_end()));
469 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
471 OS
<< "ExpressionTypeAggregateValue, ";
472 this->BasicExpression::printInternal(OS
, false);
473 OS
<< ", intoperands = {";
474 for (unsigned i
= 0, e
= int_op_size(); i
!= e
; ++i
) {
475 OS
<< "[" << i
<< "] = " << IntOperands
[i
] << " ";
481 class int_op_inserter
482 : public std::iterator
<std::output_iterator_tag
, void, void, void, void> {
484 using Container
= AggregateValueExpression
;
489 explicit int_op_inserter(AggregateValueExpression
&E
) : AVE(&E
) {}
490 explicit int_op_inserter(AggregateValueExpression
*E
) : AVE(E
) {}
492 int_op_inserter
&operator=(unsigned int val
) {
493 AVE
->int_op_push_back(val
);
496 int_op_inserter
&operator*() { return *this; }
497 int_op_inserter
&operator++() { return *this; }
498 int_op_inserter
&operator++(int) { return *this; }
501 class PHIExpression final
: public BasicExpression
{
506 PHIExpression(unsigned NumOperands
, BasicBlock
*B
)
507 : BasicExpression(NumOperands
, ET_Phi
), BB(B
) {}
508 PHIExpression() = delete;
509 PHIExpression(const PHIExpression
&) = delete;
510 PHIExpression
&operator=(const PHIExpression
&) = delete;
511 ~PHIExpression() override
;
513 static bool classof(const Expression
*EB
) {
514 return EB
->getExpressionType() == ET_Phi
;
517 bool equals(const Expression
&Other
) const override
{
518 if (!this->BasicExpression::equals(Other
))
520 const PHIExpression
&OE
= cast
<PHIExpression
>(Other
);
524 hash_code
getHashValue() const override
{
525 return hash_combine(this->BasicExpression::getHashValue(), BB
);
529 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
531 OS
<< "ExpressionTypePhi, ";
532 this->BasicExpression::printInternal(OS
, false);
537 class DeadExpression final
: public Expression
{
539 DeadExpression() : Expression(ET_Dead
) {}
540 DeadExpression(const DeadExpression
&) = delete;
541 DeadExpression
&operator=(const DeadExpression
&) = delete;
543 static bool classof(const Expression
*E
) {
544 return E
->getExpressionType() == ET_Dead
;
548 class VariableExpression final
: public Expression
{
550 Value
*VariableValue
;
553 VariableExpression(Value
*V
) : Expression(ET_Variable
), VariableValue(V
) {}
554 VariableExpression() = delete;
555 VariableExpression(const VariableExpression
&) = delete;
556 VariableExpression
&operator=(const VariableExpression
&) = delete;
558 static bool classof(const Expression
*EB
) {
559 return EB
->getExpressionType() == ET_Variable
;
562 Value
*getVariableValue() const { return VariableValue
; }
563 void setVariableValue(Value
*V
) { VariableValue
= V
; }
565 bool equals(const Expression
&Other
) const override
{
566 const VariableExpression
&OC
= cast
<VariableExpression
>(Other
);
567 return VariableValue
== OC
.VariableValue
;
570 hash_code
getHashValue() const override
{
571 return hash_combine(this->Expression::getHashValue(),
572 VariableValue
->getType(), VariableValue
);
576 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
578 OS
<< "ExpressionTypeVariable, ";
579 this->Expression::printInternal(OS
, false);
580 OS
<< " variable = " << *VariableValue
;
584 class ConstantExpression final
: public Expression
{
586 Constant
*ConstantValue
= nullptr;
589 ConstantExpression() : Expression(ET_Constant
) {}
590 ConstantExpression(Constant
*constantValue
)
591 : Expression(ET_Constant
), ConstantValue(constantValue
) {}
592 ConstantExpression(const ConstantExpression
&) = delete;
593 ConstantExpression
&operator=(const ConstantExpression
&) = delete;
595 static bool classof(const Expression
*EB
) {
596 return EB
->getExpressionType() == ET_Constant
;
599 Constant
*getConstantValue() const { return ConstantValue
; }
600 void setConstantValue(Constant
*V
) { ConstantValue
= V
; }
602 bool equals(const Expression
&Other
) const override
{
603 const ConstantExpression
&OC
= cast
<ConstantExpression
>(Other
);
604 return ConstantValue
== OC
.ConstantValue
;
607 hash_code
getHashValue() const override
{
608 return hash_combine(this->Expression::getHashValue(),
609 ConstantValue
->getType(), ConstantValue
);
613 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
615 OS
<< "ExpressionTypeConstant, ";
616 this->Expression::printInternal(OS
, false);
617 OS
<< " constant = " << *ConstantValue
;
621 class UnknownExpression final
: public Expression
{
626 UnknownExpression(Instruction
*I
) : Expression(ET_Unknown
), Inst(I
) {}
627 UnknownExpression() = delete;
628 UnknownExpression(const UnknownExpression
&) = delete;
629 UnknownExpression
&operator=(const UnknownExpression
&) = delete;
631 static bool classof(const Expression
*EB
) {
632 return EB
->getExpressionType() == ET_Unknown
;
635 Instruction
*getInstruction() const { return Inst
; }
636 void setInstruction(Instruction
*I
) { Inst
= I
; }
638 bool equals(const Expression
&Other
) const override
{
639 const auto &OU
= cast
<UnknownExpression
>(Other
);
640 return Inst
== OU
.Inst
;
643 hash_code
getHashValue() const override
{
644 return hash_combine(this->Expression::getHashValue(), Inst
);
648 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
650 OS
<< "ExpressionTypeUnknown, ";
651 this->Expression::printInternal(OS
, false);
652 OS
<< " inst = " << *Inst
;
656 } // end namespace GVNExpression
658 } // end namespace llvm
660 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H