[ORC] Add std::tuple support to SimplePackedSerialization.
[llvm-project.git] / llvm / lib / Analysis / ScalarEvolutionDivision.cpp
blob64e908bdf342e399c18cddbedfc55fce98a05f8b
1 //===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the class that knows how to divide SCEV's.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/Analysis/ScalarEvolutionDivision.h"
14 #include "llvm/ADT/APInt.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include <cassert>
22 #include <cstdint>
24 namespace llvm {
25 class Type;
28 using namespace llvm;
30 namespace {
32 static inline int sizeOfSCEV(const SCEV *S) {
33 struct FindSCEVSize {
34 int Size = 0;
36 FindSCEVSize() = default;
38 bool follow(const SCEV *S) {
39 ++Size;
40 // Keep looking at all operands of S.
41 return true;
44 bool isDone() const { return false; }
47 FindSCEVSize F;
48 SCEVTraversal<FindSCEVSize> ST(F);
49 ST.visitAll(S);
50 return F.Size;
53 } // namespace
55 // Computes the Quotient and Remainder of the division of Numerator by
56 // Denominator.
57 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
58 const SCEV *Denominator, const SCEV **Quotient,
59 const SCEV **Remainder) {
60 assert(Numerator && Denominator && "Uninitialized SCEV");
62 SCEVDivision D(SE, Numerator, Denominator);
64 // Check for the trivial case here to avoid having to check for it in the
65 // rest of the code.
66 if (Numerator == Denominator) {
67 *Quotient = D.One;
68 *Remainder = D.Zero;
69 return;
72 if (Numerator->isZero()) {
73 *Quotient = D.Zero;
74 *Remainder = D.Zero;
75 return;
78 // A simple case when N/1. The quotient is N.
79 if (Denominator->isOne()) {
80 *Quotient = Numerator;
81 *Remainder = D.Zero;
82 return;
85 // Split the Denominator when it is a product.
86 if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
87 const SCEV *Q, *R;
88 *Quotient = Numerator;
89 for (const SCEV *Op : T->operands()) {
90 divide(SE, *Quotient, Op, &Q, &R);
91 *Quotient = Q;
93 // Bail out when the Numerator is not divisible by one of the terms of
94 // the Denominator.
95 if (!R->isZero()) {
96 *Quotient = D.Zero;
97 *Remainder = Numerator;
98 return;
101 *Remainder = D.Zero;
102 return;
105 D.visit(Numerator);
106 *Quotient = D.Quotient;
107 *Remainder = D.Remainder;
110 void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
111 if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
112 APInt NumeratorVal = Numerator->getAPInt();
113 APInt DenominatorVal = D->getAPInt();
114 uint32_t NumeratorBW = NumeratorVal.getBitWidth();
115 uint32_t DenominatorBW = DenominatorVal.getBitWidth();
117 if (NumeratorBW > DenominatorBW)
118 DenominatorVal = DenominatorVal.sext(NumeratorBW);
119 else if (NumeratorBW < DenominatorBW)
120 NumeratorVal = NumeratorVal.sext(DenominatorBW);
122 APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
123 APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
124 APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
125 Quotient = SE.getConstant(QuotientVal);
126 Remainder = SE.getConstant(RemainderVal);
127 return;
131 void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
132 const SCEV *StartQ, *StartR, *StepQ, *StepR;
133 if (!Numerator->isAffine())
134 return cannotDivide(Numerator);
135 divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
136 divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
137 // Bail out if the types do not match.
138 Type *Ty = Denominator->getType();
139 if (Ty != StartQ->getType() || Ty != StartR->getType() ||
140 Ty != StepQ->getType() || Ty != StepR->getType())
141 return cannotDivide(Numerator);
142 Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
143 Numerator->getNoWrapFlags());
144 Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
145 Numerator->getNoWrapFlags());
148 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
149 SmallVector<const SCEV *, 2> Qs, Rs;
150 Type *Ty = Denominator->getType();
152 for (const SCEV *Op : Numerator->operands()) {
153 const SCEV *Q, *R;
154 divide(SE, Op, Denominator, &Q, &R);
156 // Bail out if types do not match.
157 if (Ty != Q->getType() || Ty != R->getType())
158 return cannotDivide(Numerator);
160 Qs.push_back(Q);
161 Rs.push_back(R);
164 if (Qs.size() == 1) {
165 Quotient = Qs[0];
166 Remainder = Rs[0];
167 return;
170 Quotient = SE.getAddExpr(Qs);
171 Remainder = SE.getAddExpr(Rs);
174 void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
175 SmallVector<const SCEV *, 2> Qs;
176 Type *Ty = Denominator->getType();
178 bool FoundDenominatorTerm = false;
179 for (const SCEV *Op : Numerator->operands()) {
180 // Bail out if types do not match.
181 if (Ty != Op->getType())
182 return cannotDivide(Numerator);
184 if (FoundDenominatorTerm) {
185 Qs.push_back(Op);
186 continue;
189 // Check whether Denominator divides one of the product operands.
190 const SCEV *Q, *R;
191 divide(SE, Op, Denominator, &Q, &R);
192 if (!R->isZero()) {
193 Qs.push_back(Op);
194 continue;
197 // Bail out if types do not match.
198 if (Ty != Q->getType())
199 return cannotDivide(Numerator);
201 FoundDenominatorTerm = true;
202 Qs.push_back(Q);
205 if (FoundDenominatorTerm) {
206 Remainder = Zero;
207 if (Qs.size() == 1)
208 Quotient = Qs[0];
209 else
210 Quotient = SE.getMulExpr(Qs);
211 return;
214 if (!isa<SCEVUnknown>(Denominator))
215 return cannotDivide(Numerator);
217 // The Remainder is obtained by replacing Denominator by 0 in Numerator.
218 ValueToSCEVMapTy RewriteMap;
219 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
220 Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
222 if (Remainder->isZero()) {
223 // The Quotient is obtained by replacing Denominator by 1 in Numerator.
224 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
225 Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
226 return;
229 // Quotient is (Numerator - Remainder) divided by Denominator.
230 const SCEV *Q, *R;
231 const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
232 // This SCEV does not seem to simplify: fail the division here.
233 if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
234 return cannotDivide(Numerator);
235 divide(SE, Diff, Denominator, &Q, &R);
236 if (R != Zero)
237 return cannotDivide(Numerator);
238 Quotient = Q;
241 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
242 const SCEV *Denominator)
243 : SE(S), Denominator(Denominator) {
244 Zero = SE.getZero(Denominator->getType());
245 One = SE.getOne(Denominator->getType());
247 // We generally do not know how to divide Expr by Denominator. We initialize
248 // the division to a "cannot divide" state to simplify the rest of the code.
249 cannotDivide(Numerator);
252 // Convenience function for giving up on the division. We set the quotient to
253 // be equal to zero and the remainder to be equal to the numerator.
254 void SCEVDivision::cannotDivide(const SCEV *Numerator) {
255 Quotient = Zero;
256 Remainder = Numerator;