Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / unittests / Support / DivisionByConstantTest.cpp
blob c0b708e277f204482a779f3b7f029ee30cdfd6de
1 //===- llvm/unittest/Support/DivisionByConstantTest.cpp -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/ADT/APInt.h"
10 #include "llvm/Support/DivisionByConstantInfo.h"
11 #include "gtest/gtest.h"
12 #include <array>
13 #include <optional>
15 using namespace llvm;
17 namespace {
19 template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
20 APInt N(Bits, 0);
21 do {
22 TestFn(N);
23 } while (++N != 0);
26 APInt MULHS(APInt X, APInt Y) {
27 unsigned Bits = X.getBitWidth();
28 unsigned WideBits = 2 * Bits;
29 return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
32 APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
33 SignedDivisionByConstantInfo Magics) {
34 unsigned Bits = Numerator.getBitWidth();
36 APInt Factor(Bits, 0);
37 APInt ShiftMask(Bits, -1);
38 if (Divisor.isOne() || Divisor.isAllOnes()) {
39 // If d is +1/-1, we just multiply the numerator by +1/-1.
40 Factor = Divisor.getSExtValue();
41 Magics.Magic = 0;
42 Magics.ShiftAmount = 0;
43 ShiftMask = 0;
44 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) {
45 // If d > 0 and m < 0, add the numerator.
46 Factor = 1;
47 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) {
48 // If d < 0 and m > 0, subtract the numerator.
49 Factor = -1;
52 // Multiply the numerator by the magic value.
53 APInt Q = MULHS(Numerator, Magics.Magic);
55 // (Optionally) Add/subtract the numerator using Factor.
56 Factor = Numerator * Factor;
57 Q = Q + Factor;
59 // Shift right algebraic by shift value.
60 Q = Q.ashr(Magics.ShiftAmount);
62 // Extract the sign bit, mask it and add it to the quotient.
63 unsigned SignShift = Bits - 1;
64 APInt T = Q.lshr(SignShift);
65 T = T & ShiftMask;
66 return Q + T;
69 TEST(SignedDivisionByConstantTest, Test) {
70 for (unsigned Bits = 1; Bits <= 32; ++Bits) {
71 if (Bits < 3)
72 continue; // Not supported by `SignedDivisionByConstantInfo::get()`.
73 if (Bits > 12)
74 continue; // Unreasonably slow.
75 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
76 if (Divisor.isZero())
77 return; // Division by zero is undefined behavior.
78 SignedDivisionByConstantInfo Magics;
79 if (!(Divisor.isOne() || Divisor.isAllOnes()))
80 Magics = SignedDivisionByConstantInfo::get(Divisor);
81 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
82 if (Numerator.isMinSignedValue() && Divisor.isAllOnes())
83 return; // Overflow is undefined behavior.
84 APInt NativeResult = Numerator.sdiv(Divisor);
85 APInt MagicResult = SignedDivideUsingMagic(Numerator, Divisor, Magics);
86 ASSERT_EQ(MagicResult, NativeResult)
87 << " ... given the operation: srem i" << Bits << " " << Numerator
88 << ", " << Divisor;
89 });
90 });
94 APInt MULHU(APInt X, APInt Y) {
95 unsigned Bits = X.getBitWidth();
96 unsigned WideBits = 2 * Bits;
97 return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
100 APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
101 bool LZOptimization,
102 bool AllowEvenDivisorOptimization, bool ForceNPQ,
103 UnsignedDivisionByConstantInfo Magics) {
104 assert(!Divisor.isOne() && "Division by 1 is not supported using Magic.");
106 unsigned Bits = Numerator.getBitWidth();
108 if (LZOptimization) {
109 unsigned LeadingZeros = Numerator.countl_zero();
110 // Clip to the number of leading zeros in the divisor.
111 LeadingZeros = std::min(LeadingZeros, Divisor.countl_zero());
112 if (LeadingZeros > 0) {
113 Magics = UnsignedDivisionByConstantInfo::get(
114 Divisor, LeadingZeros, AllowEvenDivisorOptimization);
115 assert(!Magics.IsAdd && "Should use cheap fixup now");
119 assert(Magics.PreShift < Divisor.getBitWidth() &&
120 "We shouldn't generate an undefined shift!");
121 assert(Magics.PostShift < Divisor.getBitWidth() &&
122 "We shouldn't generate an undefined shift!");
123 assert((!Magics.IsAdd || Magics.PreShift == 0) && "Unexpected pre-shift");
124 unsigned PreShift = Magics.PreShift;
125 unsigned PostShift = Magics.PostShift;
126 bool UseNPQ = Magics.IsAdd;
128 APInt NPQFactor =
129 UseNPQ ? APInt::getSignedMinValue(Bits) : APInt::getZero(Bits);
131 APInt Q = Numerator.lshr(PreShift);
133 // Multiply the numerator by the magic value.
134 Q = MULHU(Q, Magics.Magic);
136 if (UseNPQ || ForceNPQ) {
137 APInt NPQ = Numerator - Q;
139 // For vectors we might have a mix of non-NPQ/NPQ paths, so use
140 // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
141 APInt NPQ_Scalar = NPQ.lshr(1);
142 (void)NPQ_Scalar;
143 NPQ = MULHU(NPQ, NPQFactor);
144 assert(!UseNPQ || NPQ == NPQ_Scalar);
146 Q = NPQ + Q;
149 Q = Q.lshr(PostShift);
151 return Q;
154 TEST(UnsignedDivisionByConstantTest, Test) {
155 for (unsigned Bits = 1; Bits <= 32; ++Bits) {
156 if (Bits < 2)
157 continue; // Not supported by `UnsignedDivisionByConstantInfo::get()`.
158 if (Bits > 10)
159 continue; // Unreasonably slow.
160 EnumerateAPInts(Bits, [Bits](const APInt &Divisor) {
161 if (Divisor.isZero())
162 return; // Division by zero is undefined behavior.
163 if (Divisor.isOne())
164 return; // Division by one is the numerator.
166 const UnsignedDivisionByConstantInfo Magics =
167 UnsignedDivisionByConstantInfo::get(Divisor);
168 EnumerateAPInts(Bits, [Divisor, Magics, Bits](const APInt &Numerator) {
169 APInt NativeResult = Numerator.udiv(Divisor);
170 for (bool LZOptimization : {true, false}) {
171 for (bool AllowEvenDivisorOptimization : {true, false}) {
172 for (bool ForceNPQ : {false, true}) {
173 APInt MagicResult = UnsignedDivideUsingMagic(
174 Numerator, Divisor, LZOptimization,
175 AllowEvenDivisorOptimization, ForceNPQ, Magics);
176 ASSERT_EQ(MagicResult, NativeResult)
177 << " ... given the operation: urem i" << Bits << " "
178 << Numerator << ", " << Divisor
179 << " (allow LZ optimization = "
180 << LZOptimization << ", allow even divisior optimization = "
181 << AllowEvenDivisorOptimization << ", force NPQ = "
182 << ForceNPQ << ")";
191 } // end anonymous namespace