1 //===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines the class that knows how to divide SCEV's.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/Analysis/ScalarEvolutionDivision.h"
14 #include "llvm/ADT/APInt.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ErrorHandling.h"
32 static inline int sizeOfSCEV(const SCEV
*S
) {
36 FindSCEVSize() = default;
38 bool follow(const SCEV
*S
) {
40 // Keep looking at all operands of S.
44 bool isDone() const { return false; }
48 SCEVTraversal
<FindSCEVSize
> ST(F
);
55 // Computes the Quotient and Remainder of the division of Numerator by
57 void SCEVDivision::divide(ScalarEvolution
&SE
, const SCEV
*Numerator
,
58 const SCEV
*Denominator
, const SCEV
**Quotient
,
59 const SCEV
**Remainder
) {
60 assert(Numerator
&& Denominator
&& "Uninitialized SCEV");
62 SCEVDivision
D(SE
, Numerator
, Denominator
);
64 // Check for the trivial case here to avoid having to check for it in the
66 if (Numerator
== Denominator
) {
72 if (Numerator
->isZero()) {
78 // A simple case when N/1. The quotient is N.
79 if (Denominator
->isOne()) {
80 *Quotient
= Numerator
;
85 // Split the Denominator when it is a product.
86 if (const SCEVMulExpr
*T
= dyn_cast
<SCEVMulExpr
>(Denominator
)) {
88 *Quotient
= Numerator
;
89 for (const SCEV
*Op
: T
->operands()) {
90 divide(SE
, *Quotient
, Op
, &Q
, &R
);
93 // Bail out when the Numerator is not divisible by one of the terms of
97 *Remainder
= Numerator
;
106 *Quotient
= D
.Quotient
;
107 *Remainder
= D
.Remainder
;
110 void SCEVDivision::visitConstant(const SCEVConstant
*Numerator
) {
111 if (const SCEVConstant
*D
= dyn_cast
<SCEVConstant
>(Denominator
)) {
112 APInt NumeratorVal
= Numerator
->getAPInt();
113 APInt DenominatorVal
= D
->getAPInt();
114 uint32_t NumeratorBW
= NumeratorVal
.getBitWidth();
115 uint32_t DenominatorBW
= DenominatorVal
.getBitWidth();
117 if (NumeratorBW
> DenominatorBW
)
118 DenominatorVal
= DenominatorVal
.sext(NumeratorBW
);
119 else if (NumeratorBW
< DenominatorBW
)
120 NumeratorVal
= NumeratorVal
.sext(DenominatorBW
);
122 APInt
QuotientVal(NumeratorVal
.getBitWidth(), 0);
123 APInt
RemainderVal(NumeratorVal
.getBitWidth(), 0);
124 APInt::sdivrem(NumeratorVal
, DenominatorVal
, QuotientVal
, RemainderVal
);
125 Quotient
= SE
.getConstant(QuotientVal
);
126 Remainder
= SE
.getConstant(RemainderVal
);
131 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr
*Numerator
) {
132 const SCEV
*StartQ
, *StartR
, *StepQ
, *StepR
;
133 if (!Numerator
->isAffine())
134 return cannotDivide(Numerator
);
135 divide(SE
, Numerator
->getStart(), Denominator
, &StartQ
, &StartR
);
136 divide(SE
, Numerator
->getStepRecurrence(SE
), Denominator
, &StepQ
, &StepR
);
137 // Bail out if the types do not match.
138 Type
*Ty
= Denominator
->getType();
139 if (Ty
!= StartQ
->getType() || Ty
!= StartR
->getType() ||
140 Ty
!= StepQ
->getType() || Ty
!= StepR
->getType())
141 return cannotDivide(Numerator
);
142 Quotient
= SE
.getAddRecExpr(StartQ
, StepQ
, Numerator
->getLoop(),
143 Numerator
->getNoWrapFlags());
144 Remainder
= SE
.getAddRecExpr(StartR
, StepR
, Numerator
->getLoop(),
145 Numerator
->getNoWrapFlags());
148 void SCEVDivision::visitAddExpr(const SCEVAddExpr
*Numerator
) {
149 SmallVector
<const SCEV
*, 2> Qs
, Rs
;
150 Type
*Ty
= Denominator
->getType();
152 for (const SCEV
*Op
: Numerator
->operands()) {
154 divide(SE
, Op
, Denominator
, &Q
, &R
);
156 // Bail out if types do not match.
157 if (Ty
!= Q
->getType() || Ty
!= R
->getType())
158 return cannotDivide(Numerator
);
164 if (Qs
.size() == 1) {
170 Quotient
= SE
.getAddExpr(Qs
);
171 Remainder
= SE
.getAddExpr(Rs
);
174 void SCEVDivision::visitMulExpr(const SCEVMulExpr
*Numerator
) {
175 SmallVector
<const SCEV
*, 2> Qs
;
176 Type
*Ty
= Denominator
->getType();
178 bool FoundDenominatorTerm
= false;
179 for (const SCEV
*Op
: Numerator
->operands()) {
180 // Bail out if types do not match.
181 if (Ty
!= Op
->getType())
182 return cannotDivide(Numerator
);
184 if (FoundDenominatorTerm
) {
189 // Check whether Denominator divides one of the product operands.
191 divide(SE
, Op
, Denominator
, &Q
, &R
);
197 // Bail out if types do not match.
198 if (Ty
!= Q
->getType())
199 return cannotDivide(Numerator
);
201 FoundDenominatorTerm
= true;
205 if (FoundDenominatorTerm
) {
210 Quotient
= SE
.getMulExpr(Qs
);
214 if (!isa
<SCEVUnknown
>(Denominator
))
215 return cannotDivide(Numerator
);
217 // The Remainder is obtained by replacing Denominator by 0 in Numerator.
218 ValueToSCEVMapTy RewriteMap
;
219 RewriteMap
[cast
<SCEVUnknown
>(Denominator
)->getValue()] = Zero
;
220 Remainder
= SCEVParameterRewriter::rewrite(Numerator
, SE
, RewriteMap
);
222 if (Remainder
->isZero()) {
223 // The Quotient is obtained by replacing Denominator by 1 in Numerator.
224 RewriteMap
[cast
<SCEVUnknown
>(Denominator
)->getValue()] = One
;
225 Quotient
= SCEVParameterRewriter::rewrite(Numerator
, SE
, RewriteMap
);
229 // Quotient is (Numerator - Remainder) divided by Denominator.
231 const SCEV
*Diff
= SE
.getMinusSCEV(Numerator
, Remainder
);
232 // This SCEV does not seem to simplify: fail the division here.
233 if (sizeOfSCEV(Diff
) > sizeOfSCEV(Numerator
))
234 return cannotDivide(Numerator
);
235 divide(SE
, Diff
, Denominator
, &Q
, &R
);
237 return cannotDivide(Numerator
);
241 SCEVDivision::SCEVDivision(ScalarEvolution
&S
, const SCEV
*Numerator
,
242 const SCEV
*Denominator
)
243 : SE(S
), Denominator(Denominator
) {
244 Zero
= SE
.getZero(Denominator
->getType());
245 One
= SE
.getOne(Denominator
->getType());
247 // We generally do not know how to divide Expr by Denominator. We initialize
248 // the division to a "cannot divide" state to simplify the rest of the code.
249 cannotDivide(Numerator
);
252 // Convenience function for giving up on the division. We set the quotient to
253 // be equal to zero and the remainder to be equal to the numerator.
254 void SCEVDivision::cannotDivide(const SCEV
*Numerator
) {
256 Remainder
= Numerator
;