1 //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
/// The header file for the GVN pass that contains expression handling classes.
14 //===----------------------------------------------------------------------===//
16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Analysis/MemorySSA.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Allocator.h"
26 #include "llvm/Support/ArrayRecycler.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/raw_ostream.h"
40 namespace GVNExpression
{
64 mutable hash_code HashVal
= 0;
67 Expression(ExpressionType ET
= ET_Base
, unsigned O
= ~2U)
68 : EType(ET
), Opcode(O
) {}
69 Expression(const Expression
&) = delete;
70 Expression
&operator=(const Expression
&) = delete;
71 virtual ~Expression();
/// Opcode value reserved for the "empty" key; operator== special-cases it.
static unsigned getEmptyKey() {
  return ~0u;
}
/// Opcode value reserved for the "tombstone" key; operator== special-cases it
/// as well.
static unsigned getTombstoneKey() {
  return ~1u;
}
76 bool operator!=(const Expression
&Other
) const { return !(*this == Other
); }
77 bool operator==(const Expression
&Other
) const {
78 if (getOpcode() != Other
.getOpcode())
80 if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
82 // Compare the expression type for anything but load and store.
83 // For load and store we set the opcode to zero to make them equal.
84 if (getExpressionType() != ET_Load
&& getExpressionType() != ET_Store
&&
85 getExpressionType() != Other
.getExpressionType())
91 hash_code
getComputedHash() const {
// It's theoretically possible for a thing to hash to zero. In that case,
// we will just compute the hash a few extra times, which is no worse than
// what we did before, which was to compute it always.
95 if (static_cast<unsigned>(HashVal
) == 0)
96 HashVal
= getHashValue();
100 virtual bool equals(const Expression
&Other
) const { return true; }
102 // Return true if the two expressions are exactly the same, including the
103 // normally ignored fields.
104 virtual bool exactlyEquals(const Expression
&Other
) const {
105 return getExpressionType() == Other
.getExpressionType() && equals(Other
);
108 unsigned getOpcode() const { return Opcode
; }
109 void setOpcode(unsigned opcode
) { Opcode
= opcode
; }
110 ExpressionType
getExpressionType() const { return EType
; }
112 // We deliberately leave the expression type out of the hash value.
113 virtual hash_code
getHashValue() const { return getOpcode(); }
116 virtual void printInternal(raw_ostream
&OS
, bool PrintEType
) const {
118 OS
<< "etype = " << getExpressionType() << ",";
119 OS
<< "opcode = " << getOpcode() << ", ";
122 void print(raw_ostream
&OS
) const {
124 printInternal(OS
, true);
128 LLVM_DUMP_METHOD
void dump() const;
131 inline raw_ostream
&operator<<(raw_ostream
&OS
, const Expression
&E
) {
136 class BasicExpression
: public Expression
{
138 using RecyclerType
= ArrayRecycler
<Value
*>;
139 using RecyclerCapacity
= RecyclerType::Capacity
;
141 Value
**Operands
= nullptr;
142 unsigned MaxOperands
;
143 unsigned NumOperands
= 0;
144 Type
*ValueType
= nullptr;
147 BasicExpression(unsigned NumOperands
)
148 : BasicExpression(NumOperands
, ET_Basic
) {}
149 BasicExpression(unsigned NumOperands
, ExpressionType ET
)
150 : Expression(ET
), MaxOperands(NumOperands
) {}
151 BasicExpression() = delete;
152 BasicExpression(const BasicExpression
&) = delete;
153 BasicExpression
&operator=(const BasicExpression
&) = delete;
154 ~BasicExpression() override
;
156 static bool classof(const Expression
*EB
) {
157 ExpressionType ET
= EB
->getExpressionType();
158 return ET
> ET_BasicStart
&& ET
< ET_BasicEnd
;
/// Swap two operands. Used during GVN to put commutative operands in order.
163 void swapOperands(unsigned First
, unsigned Second
) {
164 std::swap(Operands
[First
], Operands
[Second
]);
167 Value
*getOperand(unsigned N
) const {
168 assert(Operands
&& "Operands not allocated");
169 assert(N
< NumOperands
&& "Operand out of range");
173 void setOperand(unsigned N
, Value
*V
) {
174 assert(Operands
&& "Operands not allocated before setting");
175 assert(N
< NumOperands
&& "Operand out of range");
179 unsigned getNumOperands() const { return NumOperands
; }
181 using op_iterator
= Value
**;
182 using const_op_iterator
= Value
*const *;
184 op_iterator
op_begin() { return Operands
; }
185 op_iterator
op_end() { return Operands
+ NumOperands
; }
186 const_op_iterator
op_begin() const { return Operands
; }
187 const_op_iterator
op_end() const { return Operands
+ NumOperands
; }
188 iterator_range
<op_iterator
> operands() {
189 return iterator_range
<op_iterator
>(op_begin(), op_end());
191 iterator_range
<const_op_iterator
> operands() const {
192 return iterator_range
<const_op_iterator
>(op_begin(), op_end());
195 void op_push_back(Value
*Arg
) {
196 assert(NumOperands
< MaxOperands
&& "Tried to add too many operands");
197 assert(Operands
&& "Operandss not allocated before pushing");
198 Operands
[NumOperands
++] = Arg
;
200 bool op_empty() const { return getNumOperands() == 0; }
202 void allocateOperands(RecyclerType
&Recycler
, BumpPtrAllocator
&Allocator
) {
203 assert(!Operands
&& "Operands already allocated");
204 Operands
= Recycler
.allocate(RecyclerCapacity::get(MaxOperands
), Allocator
);
206 void deallocateOperands(RecyclerType
&Recycler
) {
207 Recycler
.deallocate(RecyclerCapacity::get(MaxOperands
), Operands
);
210 void setType(Type
*T
) { ValueType
= T
; }
211 Type
*getType() const { return ValueType
; }
213 bool equals(const Expression
&Other
) const override
{
214 if (getOpcode() != Other
.getOpcode())
217 const auto &OE
= cast
<BasicExpression
>(Other
);
218 return getType() == OE
.getType() && NumOperands
== OE
.NumOperands
&&
219 std::equal(op_begin(), op_end(), OE
.op_begin());
222 hash_code
getHashValue() const override
{
223 return hash_combine(this->Expression::getHashValue(), ValueType
,
224 hash_combine_range(op_begin(), op_end()));
228 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
230 OS
<< "ExpressionTypeBasic, ";
232 this->Expression::printInternal(OS
, false);
233 OS
<< "operands = {";
234 for (unsigned i
= 0, e
= getNumOperands(); i
!= e
; ++i
) {
235 OS
<< "[" << i
<< "] = ";
236 Operands
[i
]->printAsOperand(OS
);
244 : public std::iterator
<std::output_iterator_tag
, void, void, void, void> {
246 using Container
= BasicExpression
;
251 explicit op_inserter(BasicExpression
&E
) : BE(&E
) {}
252 explicit op_inserter(BasicExpression
*E
) : BE(E
) {}
254 op_inserter
&operator=(Value
*val
) {
255 BE
->op_push_back(val
);
258 op_inserter
&operator*() { return *this; }
259 op_inserter
&operator++() { return *this; }
260 op_inserter
&operator++(int) { return *this; }
263 class MemoryExpression
: public BasicExpression
{
265 const MemoryAccess
*MemoryLeader
;
268 MemoryExpression(unsigned NumOperands
, enum ExpressionType EType
,
269 const MemoryAccess
*MemoryLeader
)
270 : BasicExpression(NumOperands
, EType
), MemoryLeader(MemoryLeader
) {}
271 MemoryExpression() = delete;
272 MemoryExpression(const MemoryExpression
&) = delete;
273 MemoryExpression
&operator=(const MemoryExpression
&) = delete;
275 static bool classof(const Expression
*EB
) {
276 return EB
->getExpressionType() > ET_MemoryStart
&&
277 EB
->getExpressionType() < ET_MemoryEnd
;
280 hash_code
getHashValue() const override
{
281 return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader
);
284 bool equals(const Expression
&Other
) const override
{
285 if (!this->BasicExpression::equals(Other
))
287 const MemoryExpression
&OtherMCE
= cast
<MemoryExpression
>(Other
);
289 return MemoryLeader
== OtherMCE
.MemoryLeader
;
292 const MemoryAccess
*getMemoryLeader() const { return MemoryLeader
; }
293 void setMemoryLeader(const MemoryAccess
*ML
) { MemoryLeader
= ML
; }
296 class CallExpression final
: public MemoryExpression
{
301 CallExpression(unsigned NumOperands
, CallInst
*C
,
302 const MemoryAccess
*MemoryLeader
)
303 : MemoryExpression(NumOperands
, ET_Call
, MemoryLeader
), Call(C
) {}
304 CallExpression() = delete;
305 CallExpression(const CallExpression
&) = delete;
306 CallExpression
&operator=(const CallExpression
&) = delete;
307 ~CallExpression() override
;
309 static bool classof(const Expression
*EB
) {
310 return EB
->getExpressionType() == ET_Call
;
314 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
316 OS
<< "ExpressionTypeCall, ";
317 this->BasicExpression::printInternal(OS
, false);
318 OS
<< " represents call at ";
319 Call
->printAsOperand(OS
);
323 class LoadExpression final
: public MemoryExpression
{
326 MaybeAlign Alignment
;
329 LoadExpression(unsigned NumOperands
, LoadInst
*L
,
330 const MemoryAccess
*MemoryLeader
)
331 : LoadExpression(ET_Load
, NumOperands
, L
, MemoryLeader
) {}
333 LoadExpression(enum ExpressionType EType
, unsigned NumOperands
, LoadInst
*L
,
334 const MemoryAccess
*MemoryLeader
)
335 : MemoryExpression(NumOperands
, EType
, MemoryLeader
), Load(L
) {
337 Alignment
= MaybeAlign(L
->getAlignment());
340 LoadExpression() = delete;
341 LoadExpression(const LoadExpression
&) = delete;
342 LoadExpression
&operator=(const LoadExpression
&) = delete;
343 ~LoadExpression() override
;
345 static bool classof(const Expression
*EB
) {
346 return EB
->getExpressionType() == ET_Load
;
349 LoadInst
*getLoadInst() const { return Load
; }
350 void setLoadInst(LoadInst
*L
) { Load
= L
; }
352 MaybeAlign
getAlignment() const { return Alignment
; }
353 void setAlignment(MaybeAlign Align
) { Alignment
= Align
; }
355 bool equals(const Expression
&Other
) const override
;
356 bool exactlyEquals(const Expression
&Other
) const override
{
357 return Expression::exactlyEquals(Other
) &&
358 cast
<LoadExpression
>(Other
).getLoadInst() == getLoadInst();
362 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
364 OS
<< "ExpressionTypeLoad, ";
365 this->BasicExpression::printInternal(OS
, false);
366 OS
<< " represents Load at ";
367 Load
->printAsOperand(OS
);
368 OS
<< " with MemoryLeader " << *getMemoryLeader();
372 class StoreExpression final
: public MemoryExpression
{
378 StoreExpression(unsigned NumOperands
, StoreInst
*S
, Value
*StoredValue
,
379 const MemoryAccess
*MemoryLeader
)
380 : MemoryExpression(NumOperands
, ET_Store
, MemoryLeader
), Store(S
),
381 StoredValue(StoredValue
) {}
382 StoreExpression() = delete;
383 StoreExpression(const StoreExpression
&) = delete;
384 StoreExpression
&operator=(const StoreExpression
&) = delete;
385 ~StoreExpression() override
;
387 static bool classof(const Expression
*EB
) {
388 return EB
->getExpressionType() == ET_Store
;
391 StoreInst
*getStoreInst() const { return Store
; }
392 Value
*getStoredValue() const { return StoredValue
; }
394 bool equals(const Expression
&Other
) const override
;
396 bool exactlyEquals(const Expression
&Other
) const override
{
397 return Expression::exactlyEquals(Other
) &&
398 cast
<StoreExpression
>(Other
).getStoreInst() == getStoreInst();
402 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
404 OS
<< "ExpressionTypeStore, ";
405 this->BasicExpression::printInternal(OS
, false);
406 OS
<< " represents Store " << *Store
;
407 OS
<< " with StoredValue ";
408 StoredValue
->printAsOperand(OS
);
409 OS
<< " and MemoryLeader " << *getMemoryLeader();
413 class AggregateValueExpression final
: public BasicExpression
{
// Capacity of the integer-operand array, fixed by the constructor.
unsigned MaxIntOperands;
// How many integer operands have been pushed so far.
unsigned NumIntOperands{0};
// Integer-operand storage, allocated by allocateIntOperands().
unsigned *IntOperands{nullptr};
420 AggregateValueExpression(unsigned NumOperands
, unsigned NumIntOperands
)
421 : BasicExpression(NumOperands
, ET_AggregateValue
),
422 MaxIntOperands(NumIntOperands
) {}
423 AggregateValueExpression() = delete;
424 AggregateValueExpression(const AggregateValueExpression
&) = delete;
425 AggregateValueExpression
&
426 operator=(const AggregateValueExpression
&) = delete;
427 ~AggregateValueExpression() override
;
429 static bool classof(const Expression
*EB
) {
430 return EB
->getExpressionType() == ET_AggregateValue
;
433 using int_arg_iterator
= unsigned *;
434 using const_int_arg_iterator
= const unsigned *;
436 int_arg_iterator
int_op_begin() { return IntOperands
; }
437 int_arg_iterator
int_op_end() { return IntOperands
+ NumIntOperands
; }
438 const_int_arg_iterator
int_op_begin() const { return IntOperands
; }
439 const_int_arg_iterator
int_op_end() const {
440 return IntOperands
+ NumIntOperands
;
442 unsigned int_op_size() const { return NumIntOperands
; }
443 bool int_op_empty() const { return NumIntOperands
== 0; }
444 void int_op_push_back(unsigned IntOperand
) {
445 assert(NumIntOperands
< MaxIntOperands
&&
446 "Tried to add too many int operands");
447 assert(IntOperands
&& "Operands not allocated before pushing");
448 IntOperands
[NumIntOperands
++] = IntOperand
;
451 virtual void allocateIntOperands(BumpPtrAllocator
&Allocator
) {
452 assert(!IntOperands
&& "Operands already allocated");
453 IntOperands
= Allocator
.Allocate
<unsigned>(MaxIntOperands
);
456 bool equals(const Expression
&Other
) const override
{
457 if (!this->BasicExpression::equals(Other
))
459 const AggregateValueExpression
&OE
= cast
<AggregateValueExpression
>(Other
);
460 return NumIntOperands
== OE
.NumIntOperands
&&
461 std::equal(int_op_begin(), int_op_end(), OE
.int_op_begin());
464 hash_code
getHashValue() const override
{
465 return hash_combine(this->BasicExpression::getHashValue(),
466 hash_combine_range(int_op_begin(), int_op_end()));
470 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
472 OS
<< "ExpressionTypeAggregateValue, ";
473 this->BasicExpression::printInternal(OS
, false);
474 OS
<< ", intoperands = {";
475 for (unsigned i
= 0, e
= int_op_size(); i
!= e
; ++i
) {
476 OS
<< "[" << i
<< "] = " << IntOperands
[i
] << " ";
482 class int_op_inserter
483 : public std::iterator
<std::output_iterator_tag
, void, void, void, void> {
485 using Container
= AggregateValueExpression
;
490 explicit int_op_inserter(AggregateValueExpression
&E
) : AVE(&E
) {}
491 explicit int_op_inserter(AggregateValueExpression
*E
) : AVE(E
) {}
493 int_op_inserter
&operator=(unsigned int val
) {
494 AVE
->int_op_push_back(val
);
497 int_op_inserter
&operator*() { return *this; }
498 int_op_inserter
&operator++() { return *this; }
499 int_op_inserter
&operator++(int) { return *this; }
502 class PHIExpression final
: public BasicExpression
{
507 PHIExpression(unsigned NumOperands
, BasicBlock
*B
)
508 : BasicExpression(NumOperands
, ET_Phi
), BB(B
) {}
509 PHIExpression() = delete;
510 PHIExpression(const PHIExpression
&) = delete;
511 PHIExpression
&operator=(const PHIExpression
&) = delete;
512 ~PHIExpression() override
;
514 static bool classof(const Expression
*EB
) {
515 return EB
->getExpressionType() == ET_Phi
;
518 bool equals(const Expression
&Other
) const override
{
519 if (!this->BasicExpression::equals(Other
))
521 const PHIExpression
&OE
= cast
<PHIExpression
>(Other
);
525 hash_code
getHashValue() const override
{
526 return hash_combine(this->BasicExpression::getHashValue(), BB
);
530 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
532 OS
<< "ExpressionTypePhi, ";
533 this->BasicExpression::printInternal(OS
, false);
538 class DeadExpression final
: public Expression
{
540 DeadExpression() : Expression(ET_Dead
) {}
541 DeadExpression(const DeadExpression
&) = delete;
542 DeadExpression
&operator=(const DeadExpression
&) = delete;
544 static bool classof(const Expression
*E
) {
545 return E
->getExpressionType() == ET_Dead
;
549 class VariableExpression final
: public Expression
{
551 Value
*VariableValue
;
554 VariableExpression(Value
*V
) : Expression(ET_Variable
), VariableValue(V
) {}
555 VariableExpression() = delete;
556 VariableExpression(const VariableExpression
&) = delete;
557 VariableExpression
&operator=(const VariableExpression
&) = delete;
559 static bool classof(const Expression
*EB
) {
560 return EB
->getExpressionType() == ET_Variable
;
563 Value
*getVariableValue() const { return VariableValue
; }
564 void setVariableValue(Value
*V
) { VariableValue
= V
; }
566 bool equals(const Expression
&Other
) const override
{
567 const VariableExpression
&OC
= cast
<VariableExpression
>(Other
);
568 return VariableValue
== OC
.VariableValue
;
571 hash_code
getHashValue() const override
{
572 return hash_combine(this->Expression::getHashValue(),
573 VariableValue
->getType(), VariableValue
);
577 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
579 OS
<< "ExpressionTypeVariable, ";
580 this->Expression::printInternal(OS
, false);
581 OS
<< " variable = " << *VariableValue
;
585 class ConstantExpression final
: public Expression
{
587 Constant
*ConstantValue
= nullptr;
590 ConstantExpression() : Expression(ET_Constant
) {}
591 ConstantExpression(Constant
*constantValue
)
592 : Expression(ET_Constant
), ConstantValue(constantValue
) {}
593 ConstantExpression(const ConstantExpression
&) = delete;
594 ConstantExpression
&operator=(const ConstantExpression
&) = delete;
596 static bool classof(const Expression
*EB
) {
597 return EB
->getExpressionType() == ET_Constant
;
600 Constant
*getConstantValue() const { return ConstantValue
; }
601 void setConstantValue(Constant
*V
) { ConstantValue
= V
; }
603 bool equals(const Expression
&Other
) const override
{
604 const ConstantExpression
&OC
= cast
<ConstantExpression
>(Other
);
605 return ConstantValue
== OC
.ConstantValue
;
608 hash_code
getHashValue() const override
{
609 return hash_combine(this->Expression::getHashValue(),
610 ConstantValue
->getType(), ConstantValue
);
614 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
616 OS
<< "ExpressionTypeConstant, ";
617 this->Expression::printInternal(OS
, false);
618 OS
<< " constant = " << *ConstantValue
;
622 class UnknownExpression final
: public Expression
{
627 UnknownExpression(Instruction
*I
) : Expression(ET_Unknown
), Inst(I
) {}
628 UnknownExpression() = delete;
629 UnknownExpression(const UnknownExpression
&) = delete;
630 UnknownExpression
&operator=(const UnknownExpression
&) = delete;
632 static bool classof(const Expression
*EB
) {
633 return EB
->getExpressionType() == ET_Unknown
;
636 Instruction
*getInstruction() const { return Inst
; }
637 void setInstruction(Instruction
*I
) { Inst
= I
; }
639 bool equals(const Expression
&Other
) const override
{
640 const auto &OU
= cast
<UnknownExpression
>(Other
);
641 return Inst
== OU
.Inst
;
644 hash_code
getHashValue() const override
{
645 return hash_combine(this->Expression::getHashValue(), Inst
);
649 void printInternal(raw_ostream
&OS
, bool PrintEType
) const override
{
651 OS
<< "ExpressionTypeUnknown, ";
652 this->Expression::printInternal(OS
, false);
653 OS
<< " inst = " << *Inst
;
657 } // end namespace GVNExpression
659 } // end namespace llvm
661 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H