[AMDGPU] Test codegen'ing True16 additions.
[llvm-project.git] / polly / lib / Transform / FlattenAlgo.cpp
blobf8ed332348ab1fa9f77edd1eac51047c3d502953
1 //===------ FlattenAlgo.cpp ------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
10 // the unittest for this requiring linking against LLVM.
12 //===----------------------------------------------------------------------===//
14 #include "polly/FlattenAlgo.h"
15 #include "polly/Support/ISLOStream.h"
16 #include "polly/Support/ISLTools.h"
17 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "polly-flatten-algo"
20 using namespace polly;
21 using namespace llvm;
23 namespace {
25 /// Whether a dimension of a set is bounded (lower and upper) by a constant,
26 /// i.e. there are two constants Min and Max, such that every value x of the
27 /// chosen dimensions is Min <= x <= Max.
28 bool isDimBoundedByConstant(isl::set Set, unsigned dim) {
29 auto ParamDims = unsignedFromIslSize(Set.dim(isl::dim::param));
30 Set = Set.project_out(isl::dim::param, 0, ParamDims);
31 Set = Set.project_out(isl::dim::set, 0, dim);
32 auto SetDims = unsignedFromIslSize(Set.tuple_dim());
33 assert(SetDims >= 1);
34 Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
35 return bool(Set.is_bounded());
38 /// Whether a dimension of a set is (lower and upper) bounded by a constant or
39 /// parameters, i.e. there are two expressions Min_p and Max_p of the parameters
40 /// p, such that every value x of the chosen dimensions is
41 /// Min_p <= x <= Max_p.
42 bool isDimBoundedByParameter(isl::set Set, unsigned dim) {
43 Set = Set.project_out(isl::dim::set, 0, dim);
44 auto SetDims = unsignedFromIslSize(Set.tuple_dim());
45 assert(SetDims >= 1);
46 Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
47 return bool(Set.is_bounded());
50 /// Whether BMap's first out-dimension is not a constant.
51 bool isVariableDim(const isl::basic_map &BMap) {
52 auto FixedVal = BMap.plain_get_val_if_fixed(isl::dim::out, 0);
53 return FixedVal.is_null() || FixedVal.is_nan();
56 /// Whether Map's first out dimension is no constant nor piecewise constant.
57 bool isVariableDim(const isl::map &Map) {
58 for (isl::basic_map BMap : Map.get_basic_map_list())
59 if (isVariableDim(BMap))
60 return false;
62 return true;
65 /// Whether UMap's first out dimension is no (piecewise) constant.
66 bool isVariableDim(const isl::union_map &UMap) {
67 for (isl::map Map : UMap.get_map_list())
68 if (isVariableDim(Map))
69 return false;
70 return true;
73 /// Compute @p UPwAff - @p Val.
74 isl::union_pw_aff subtract(isl::union_pw_aff UPwAff, isl::val Val) {
75 if (Val.is_zero())
76 return UPwAff;
78 auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
79 isl::stat Stat =
80 UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
81 auto ValAff =
82 isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
83 auto Subtracted = PwAff.sub(ValAff);
84 Result = Result.union_add(isl::union_pw_aff(Subtracted));
85 return isl::stat::ok();
86 });
87 if (Stat.is_error())
88 return {};
89 return Result;
92 /// Compute @UPwAff * @p Val.
93 isl::union_pw_aff multiply(isl::union_pw_aff UPwAff, isl::val Val) {
94 if (Val.is_one())
95 return UPwAff;
97 auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
98 isl::stat Stat =
99 UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
100 auto ValAff =
101 isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
102 auto Multiplied = PwAff.mul(ValAff);
103 Result = Result.union_add(Multiplied);
104 return isl::stat::ok();
106 if (Stat.is_error())
107 return {};
108 return Result;
111 /// Remove @p n dimensions from @p UMap's range, starting at @p first.
113 /// It is assumed that all maps in the maps have at least the necessary number
114 /// of out dimensions.
115 isl::union_map scheduleProjectOut(const isl::union_map &UMap, unsigned first,
116 unsigned n) {
117 if (n == 0)
118 return UMap; /* isl_map_project_out would also reset the tuple, which should
119 have no effect on schedule ranges */
121 auto Result = isl::union_map::empty(UMap.ctx());
122 for (isl::map Map : UMap.get_map_list()) {
123 auto Outprojected = Map.project_out(isl::dim::out, first, n);
124 Result = Result.unite(Outprojected);
126 return Result;
129 /// Return the @p pos' range dimension, converted to an isl_union_pw_aff.
130 isl::union_pw_aff scheduleExtractDimAff(isl::union_map UMap, unsigned pos) {
131 auto SingleUMap = isl::union_map::empty(UMap.ctx());
132 for (isl::map Map : UMap.get_map_list()) {
133 unsigned MapDims = unsignedFromIslSize(Map.range_tuple_dim());
134 assert(MapDims > pos);
135 isl::map SingleMap = Map.project_out(isl::dim::out, 0, pos);
136 SingleMap = SingleMap.project_out(isl::dim::out, 1, MapDims - pos - 1);
137 SingleUMap = SingleUMap.unite(SingleMap);
140 auto UAff = isl::union_pw_multi_aff(SingleUMap);
141 auto FirstMAff = isl::multi_union_pw_aff(UAff);
142 return FirstMAff.at(0);
145 /// Flatten a sequence-like first dimension.
147 /// A sequence-like scatter dimension is constant, or at least only small
148 /// variation, typically the result of ordering a sequence of different
149 /// statements. An example would be:
150 /// { Stmt_A[] -> [0, X, ...]; Stmt_B[] -> [1, Y, ...] }
151 /// to schedule all instances of Stmt_A before any instance of Stmt_B.
153 /// To flatten, first begin with an offset of zero. Then determine the lowest
154 /// possible value of the dimension, call it "i" [In the example we start at 0].
155 /// Considering only schedules with that value, consider only instances with
156 /// that value and determine the extent of the next dimension. Let l_X(i) and
157 /// u_X(i) its minimum (lower bound) and maximum (upper bound) value. Add them
158 /// as "Offset + X - l_X(i)" to the new schedule, then add "u_X(i) - l_X(i) + 1"
159 /// to Offset and remove all i-instances from the old schedule. Repeat with the
160 /// remaining lowest value i' until there are no instances in the old schedule
161 /// left.
162 /// The example schedule would be transformed to:
163 /// { Stmt_X[] -> [X - l_X, ...]; Stmt_B -> [l_X - u_X + 1 + Y - l_Y, ...] }
164 isl::union_map tryFlattenSequence(isl::union_map Schedule) {
165 auto IslCtx = Schedule.ctx();
166 auto ScatterSet = isl::set(Schedule.range());
168 auto ParamSpace = Schedule.get_space().params();
169 auto Dims = unsignedFromIslSize(ScatterSet.tuple_dim());
170 assert(Dims >= 2u);
172 // Would cause an infinite loop.
173 if (!isDimBoundedByConstant(ScatterSet, 0)) {
174 LLVM_DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
175 return {};
178 auto AllDomains = Schedule.domain();
179 auto AllDomainsToNull = isl::union_pw_multi_aff(AllDomains);
181 auto NewSchedule = isl::union_map::empty(ParamSpace.ctx());
182 auto Counter = isl::pw_aff(isl::local_space(ParamSpace.set_from_params()));
184 while (!ScatterSet.is_empty()) {
185 LLVM_DEBUG(dbgs() << "Next counter:\n " << Counter << "\n");
186 LLVM_DEBUG(dbgs() << "Remaining scatter set:\n " << ScatterSet << "\n");
187 auto ThisSet = ScatterSet.project_out(isl::dim::set, 1, Dims - 1);
188 auto ThisFirst = ThisSet.lexmin();
189 auto ScatterFirst = ThisFirst.add_dims(isl::dim::set, Dims - 1);
191 auto SubSchedule = Schedule.intersect_range(ScatterFirst);
192 SubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
193 SubSchedule = flattenSchedule(SubSchedule);
195 unsigned SubDims = getNumScatterDims(SubSchedule);
196 assert(SubDims >= 1);
197 auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
198 auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
199 auto RemainingSubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
201 auto FirstSubScatter = isl::set(FirstSubSchedule.range());
202 LLVM_DEBUG(dbgs() << "Next step in sequence is:\n " << FirstSubScatter
203 << "\n");
205 if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
206 LLVM_DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
207 return {};
210 auto FirstSubScatterMap = isl::map::from_range(FirstSubScatter);
212 // isl_set_dim_max returns a strange isl_pw_aff with domain tuple_id of
213 // 'none'. It doesn't match with any space including a 0-dimensional
214 // anonymous tuple.
215 // Interesting, one can create such a set using
216 // isl_set_universe(ParamSpace). Bug?
217 auto PartMin = FirstSubScatterMap.dim_min(0);
218 auto PartMax = FirstSubScatterMap.dim_max(0);
219 auto One = isl::pw_aff(isl::set::universe(ParamSpace.set_from_params()),
220 isl::val::one(IslCtx));
221 auto PartLen = PartMax.add(PartMin.neg()).add(One);
223 auto AllPartMin = isl::union_pw_aff(PartMin).pullback(AllDomainsToNull);
224 auto FirstScheduleAffNormalized = FirstScheduleAff.sub(AllPartMin);
225 auto AllCounter = isl::union_pw_aff(Counter).pullback(AllDomainsToNull);
226 auto FirstScheduleAffWithOffset =
227 FirstScheduleAffNormalized.add(AllCounter);
229 auto ScheduleWithOffset =
230 isl::union_map::from(
231 isl::union_pw_multi_aff(FirstScheduleAffWithOffset))
232 .flat_range_product(RemainingSubSchedule);
233 NewSchedule = NewSchedule.unite(ScheduleWithOffset);
235 ScatterSet = ScatterSet.subtract(ScatterFirst);
236 Counter = Counter.add(PartLen);
239 LLVM_DEBUG(dbgs() << "Sequence-flatten result is:\n " << NewSchedule
240 << "\n");
241 return NewSchedule;
244 /// Flatten a loop-like first dimension.
246 /// A loop-like dimension is one that depends on a variable (usually a loop's
247 /// induction variable). Let the input schedule look like this:
248 /// { Stmt[i] -> [i, X, ...] }
250 /// To flatten, we determine the largest extent of X which may not depend on the
251 /// actual value of i. Let l_X() the smallest possible value of X and u_X() its
252 /// largest value. Then, construct a new schedule
253 /// { Stmt[i] -> [i * (u_X() - l_X() + 1), ...] }
254 isl::union_map tryFlattenLoop(isl::union_map Schedule) {
255 assert(getNumScatterDims(Schedule) >= 2);
257 auto Remaining = scheduleProjectOut(Schedule, 0, 1);
258 auto SubSchedule = flattenSchedule(Remaining);
259 unsigned SubDims = getNumScatterDims(SubSchedule);
261 assert(SubDims >= 1);
263 auto SubExtent = isl::set(SubSchedule.range());
264 auto SubExtentDims = unsignedFromIslSize(SubExtent.dim(isl::dim::param));
265 SubExtent = SubExtent.project_out(isl::dim::param, 0, SubExtentDims);
266 SubExtent = SubExtent.project_out(isl::dim::set, 1, SubDims - 1);
268 if (!isDimBoundedByConstant(SubExtent, 0)) {
269 LLVM_DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
270 return {};
273 auto Min = SubExtent.dim_min(0);
274 LLVM_DEBUG(dbgs() << "Min bound:\n " << Min << "\n");
275 auto MinVal = getConstant(Min, false, true);
276 auto Max = SubExtent.dim_max(0);
277 LLVM_DEBUG(dbgs() << "Max bound:\n " << Max << "\n");
278 auto MaxVal = getConstant(Max, true, false);
280 if (MinVal.is_null() || MaxVal.is_null() || MinVal.is_nan() ||
281 MaxVal.is_nan()) {
282 LLVM_DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
283 return {};
286 auto FirstSubScheduleAff = scheduleExtractDimAff(SubSchedule, 0);
287 auto RemainingSubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
289 auto LenVal = MaxVal.sub(MinVal).add(1);
290 auto FirstSubScheduleNormalized = subtract(FirstSubScheduleAff, MinVal);
292 // TODO: Normalize FirstAff to zero (convert to isl_map, determine minimum,
293 // subtract it)
294 auto FirstAff = scheduleExtractDimAff(Schedule, 0);
295 auto Offset = multiply(FirstAff, LenVal);
296 isl::union_pw_multi_aff Index = FirstSubScheduleNormalized.add(Offset);
297 auto IndexMap = isl::union_map::from(Index);
299 auto Result = IndexMap.flat_range_product(RemainingSubSchedule);
300 LLVM_DEBUG(dbgs() << "Loop-flatten result is:\n " << Result << "\n");
301 return Result;
303 } // anonymous namespace
305 isl::union_map polly::flattenSchedule(isl::union_map Schedule) {
306 unsigned Dims = getNumScatterDims(Schedule);
307 LLVM_DEBUG(dbgs() << "Recursive schedule to process:\n " << Schedule
308 << "\n");
310 // Base case; no dimensions left
311 if (Dims == 0) {
312 // TODO: Add one dimension?
313 return Schedule;
316 // Base case; already one-dimensional
317 if (Dims == 1)
318 return Schedule;
320 // Fixed dimension; no need to preserve variabledness.
321 if (!isVariableDim(Schedule)) {
322 LLVM_DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
323 auto NewScheduleSequence = tryFlattenSequence(Schedule);
324 if (!NewScheduleSequence.is_null())
325 return NewScheduleSequence;
328 // Constant stride
329 LLVM_DEBUG(dbgs() << "Try loop flattening\n");
330 auto NewScheduleLoop = tryFlattenLoop(Schedule);
331 if (!NewScheduleLoop.is_null())
332 return NewScheduleLoop;
334 // Try again without loop condition (may blow up the number of pieces!!)
335 LLVM_DEBUG(dbgs() << "Try sequence flattening again\n");
336 auto NewScheduleSequence = tryFlattenSequence(Schedule);
337 if (!NewScheduleSequence.is_null())
338 return NewScheduleSequence;
340 // Cannot flatten
341 return Schedule;