//===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the class that knows how to divide SCEV expressions.
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ScalarEvolutionDivision.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <cstdint>

using namespace llvm;
30 static inline int sizeOfSCEV(const SCEV
*S
) {
34 FindSCEVSize() = default;
36 bool follow(const SCEV
*S
) {
38 // Keep looking at all operands of S.
42 bool isDone() const { return false; }
46 SCEVTraversal
<FindSCEVSize
> ST(F
);
53 // Computes the Quotient and Remainder of the division of Numerator by
55 void SCEVDivision::divide(ScalarEvolution
&SE
, const SCEV
*Numerator
,
56 const SCEV
*Denominator
, const SCEV
**Quotient
,
57 const SCEV
**Remainder
) {
58 assert(Numerator
&& Denominator
&& "Uninitialized SCEV");
60 SCEVDivision
D(SE
, Numerator
, Denominator
);
62 // Check for the trivial case here to avoid having to check for it in the
64 if (Numerator
== Denominator
) {
70 if (Numerator
->isZero()) {
76 // A simple case when N/1. The quotient is N.
77 if (Denominator
->isOne()) {
78 *Quotient
= Numerator
;
83 // Split the Denominator when it is a product.
84 if (const SCEVMulExpr
*T
= dyn_cast
<SCEVMulExpr
>(Denominator
)) {
86 *Quotient
= Numerator
;
87 for (const SCEV
*Op
: T
->operands()) {
88 divide(SE
, *Quotient
, Op
, &Q
, &R
);
91 // Bail out when the Numerator is not divisible by one of the terms of
95 *Remainder
= Numerator
;
104 *Quotient
= D
.Quotient
;
105 *Remainder
= D
.Remainder
;
108 void SCEVDivision::visitConstant(const SCEVConstant
*Numerator
) {
109 if (const SCEVConstant
*D
= dyn_cast
<SCEVConstant
>(Denominator
)) {
110 APInt NumeratorVal
= Numerator
->getAPInt();
111 APInt DenominatorVal
= D
->getAPInt();
112 uint32_t NumeratorBW
= NumeratorVal
.getBitWidth();
113 uint32_t DenominatorBW
= DenominatorVal
.getBitWidth();
115 if (NumeratorBW
> DenominatorBW
)
116 DenominatorVal
= DenominatorVal
.sext(NumeratorBW
);
117 else if (NumeratorBW
< DenominatorBW
)
118 NumeratorVal
= NumeratorVal
.sext(DenominatorBW
);
120 APInt
QuotientVal(NumeratorVal
.getBitWidth(), 0);
121 APInt
RemainderVal(NumeratorVal
.getBitWidth(), 0);
122 APInt::sdivrem(NumeratorVal
, DenominatorVal
, QuotientVal
, RemainderVal
);
123 Quotient
= SE
.getConstant(QuotientVal
);
124 Remainder
= SE
.getConstant(RemainderVal
);
129 void SCEVDivision::visitVScale(const SCEVVScale
*Numerator
) {
130 return cannotDivide(Numerator
);
133 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr
*Numerator
) {
134 const SCEV
*StartQ
, *StartR
, *StepQ
, *StepR
;
135 if (!Numerator
->isAffine())
136 return cannotDivide(Numerator
);
137 divide(SE
, Numerator
->getStart(), Denominator
, &StartQ
, &StartR
);
138 divide(SE
, Numerator
->getStepRecurrence(SE
), Denominator
, &StepQ
, &StepR
);
139 // Bail out if the types do not match.
140 Type
*Ty
= Denominator
->getType();
141 if (Ty
!= StartQ
->getType() || Ty
!= StartR
->getType() ||
142 Ty
!= StepQ
->getType() || Ty
!= StepR
->getType())
143 return cannotDivide(Numerator
);
144 Quotient
= SE
.getAddRecExpr(StartQ
, StepQ
, Numerator
->getLoop(),
145 Numerator
->getNoWrapFlags());
146 Remainder
= SE
.getAddRecExpr(StartR
, StepR
, Numerator
->getLoop(),
147 Numerator
->getNoWrapFlags());
150 void SCEVDivision::visitAddExpr(const SCEVAddExpr
*Numerator
) {
151 SmallVector
<const SCEV
*, 2> Qs
, Rs
;
152 Type
*Ty
= Denominator
->getType();
154 for (const SCEV
*Op
: Numerator
->operands()) {
156 divide(SE
, Op
, Denominator
, &Q
, &R
);
158 // Bail out if types do not match.
159 if (Ty
!= Q
->getType() || Ty
!= R
->getType())
160 return cannotDivide(Numerator
);
166 if (Qs
.size() == 1) {
172 Quotient
= SE
.getAddExpr(Qs
);
173 Remainder
= SE
.getAddExpr(Rs
);
176 void SCEVDivision::visitMulExpr(const SCEVMulExpr
*Numerator
) {
177 SmallVector
<const SCEV
*, 2> Qs
;
178 Type
*Ty
= Denominator
->getType();
180 bool FoundDenominatorTerm
= false;
181 for (const SCEV
*Op
: Numerator
->operands()) {
182 // Bail out if types do not match.
183 if (Ty
!= Op
->getType())
184 return cannotDivide(Numerator
);
186 if (FoundDenominatorTerm
) {
191 // Check whether Denominator divides one of the product operands.
193 divide(SE
, Op
, Denominator
, &Q
, &R
);
199 // Bail out if types do not match.
200 if (Ty
!= Q
->getType())
201 return cannotDivide(Numerator
);
203 FoundDenominatorTerm
= true;
207 if (FoundDenominatorTerm
) {
212 Quotient
= SE
.getMulExpr(Qs
);
216 if (!isa
<SCEVUnknown
>(Denominator
))
217 return cannotDivide(Numerator
);
219 // The Remainder is obtained by replacing Denominator by 0 in Numerator.
220 ValueToSCEVMapTy RewriteMap
;
221 RewriteMap
[cast
<SCEVUnknown
>(Denominator
)->getValue()] = Zero
;
222 Remainder
= SCEVParameterRewriter::rewrite(Numerator
, SE
, RewriteMap
);
224 if (Remainder
->isZero()) {
225 // The Quotient is obtained by replacing Denominator by 1 in Numerator.
226 RewriteMap
[cast
<SCEVUnknown
>(Denominator
)->getValue()] = One
;
227 Quotient
= SCEVParameterRewriter::rewrite(Numerator
, SE
, RewriteMap
);
231 // Quotient is (Numerator - Remainder) divided by Denominator.
233 const SCEV
*Diff
= SE
.getMinusSCEV(Numerator
, Remainder
);
234 // This SCEV does not seem to simplify: fail the division here.
235 if (sizeOfSCEV(Diff
) > sizeOfSCEV(Numerator
))
236 return cannotDivide(Numerator
);
237 divide(SE
, Diff
, Denominator
, &Q
, &R
);
239 return cannotDivide(Numerator
);
243 SCEVDivision::SCEVDivision(ScalarEvolution
&S
, const SCEV
*Numerator
,
244 const SCEV
*Denominator
)
245 : SE(S
), Denominator(Denominator
) {
246 Zero
= SE
.getZero(Denominator
->getType());
247 One
= SE
.getOne(Denominator
->getType());
249 // We generally do not know how to divide Expr by Denominator. We initialize
250 // the division to a "cannot divide" state to simplify the rest of the code.
251 cannotDivide(Numerator
);
254 // Convenience function for giving up on the division. We set the quotient to
255 // be equal to zero and the remainder to be equal to the numerator.
256 void SCEVDivision::cannotDivide(const SCEV
*Numerator
) {
258 Remainder
= Numerator
;