[AMDGPU] Test codegen'ing True16 additions.
[llvm-project.git] / polly / lib / Transform / ScheduleTreeTransform.cpp
blobe42b3d1c24604bf29929116b9157f43d591adfef
1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Make changes to isl's schedule tree data structure.
11 //===----------------------------------------------------------------------===//
13 #include "polly/ScheduleTreeTransform.h"
14 #include "polly/Support/GICHelper.h"
15 #include "polly/Support/ISLTools.h"
16 #include "polly/Support/ScopHelper.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/Metadata.h"
22 #include "llvm/Transforms/Utils/UnrollLoop.h"
24 #define DEBUG_TYPE "polly-opt-isl"
26 using namespace polly;
27 using namespace llvm;
29 namespace {
31 /// Copy the band member attributes (coincidence, loop type, isolate ast loop
32 /// type) from one band to another.
33 static isl::schedule_node_band
34 applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx,
35 const isl::schedule_node_band &Source,
36 int SourceIdx) {
37 bool Coincident = Source.member_get_coincident(SourceIdx).release();
38 Target = Target.member_set_coincident(TargetIdx, Coincident);
40 isl_ast_loop_type LoopType =
41 isl_schedule_node_band_member_get_ast_loop_type(Source.get(), SourceIdx);
42 Target = isl::manage(isl_schedule_node_band_member_set_ast_loop_type(
43 Target.release(), TargetIdx, LoopType))
44 .as<isl::schedule_node_band>();
46 isl_ast_loop_type IsolateType =
47 isl_schedule_node_band_member_get_isolate_ast_loop_type(Source.get(),
48 SourceIdx);
49 Target = isl::manage(isl_schedule_node_band_member_set_isolate_ast_loop_type(
50 Target.release(), TargetIdx, IsolateType))
51 .as<isl::schedule_node_band>();
53 return Target;
56 /// Create a new band by copying members from another @p Band. @p IncludeCb
57 /// decides which band indices are copied to the result.
58 template <typename CbTy>
59 static isl::schedule rebuildBand(isl::schedule_node_band OldBand,
60 isl::schedule Body, CbTy IncludeCb) {
61 int NumBandDims = unsignedFromIslSize(OldBand.n_member());
63 bool ExcludeAny = false;
64 bool IncludeAny = false;
65 for (auto OldIdx : seq<int>(0, NumBandDims)) {
66 if (IncludeCb(OldIdx))
67 IncludeAny = true;
68 else
69 ExcludeAny = true;
72 // Instead of creating a zero-member band, don't create a band at all.
73 if (!IncludeAny)
74 return Body;
76 isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule();
77 isl::multi_union_pw_aff NewPartialSched;
78 if (ExcludeAny) {
79 // Select the included partial scatter functions.
80 isl::union_pw_aff_list List = PartialSched.list();
81 int NewIdx = 0;
82 for (auto OldIdx : seq<int>(0, NumBandDims)) {
83 if (IncludeCb(OldIdx))
84 NewIdx += 1;
85 else
86 List = List.drop(NewIdx, 1);
88 isl::space ParamSpace = PartialSched.get_space().params();
89 isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(NewIdx);
90 NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List);
91 } else {
92 // Just reuse original scatter function of copying all of them.
93 NewPartialSched = PartialSched;
96 // Create the new band node.
97 isl::schedule_node_band NewBand =
98 Body.insert_partial_schedule(NewPartialSched)
99 .get_root()
100 .child(0)
101 .as<isl::schedule_node_band>();
103 // If OldBand was permutable, so is the new one, even if some dimensions are
104 // missing.
105 bool IsPermutable = OldBand.permutable().release();
106 NewBand = NewBand.set_permutable(IsPermutable);
108 // Reapply member attributes.
109 int NewIdx = 0;
110 for (auto OldIdx : seq<int>(0, NumBandDims)) {
111 if (!IncludeCb(OldIdx))
112 continue;
113 NewBand =
114 applyBandMemberAttributes(std::move(NewBand), NewIdx, OldBand, OldIdx);
115 NewIdx += 1;
118 return NewBand.get_schedule();
121 /// Rewrite a schedule tree by reconstructing it bottom-up.
123 /// By default, the original schedule tree is reconstructed. To build a
124 /// different tree, redefine visitor methods in a derived class (CRTP).
126 /// Note that AST build options are not applied; Setting the isolate[] option
127 /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence,
128 /// AST build options must be set after the tree has been constructed.
129 template <typename Derived, typename... Args>
130 struct ScheduleTreeRewriter
131 : RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
132 Derived &getDerived() { return *static_cast<Derived *>(this); }
133 const Derived &getDerived() const {
134 return *static_cast<const Derived *>(this);
137 isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) {
138 // Every schedule_tree already has a domain node, no need to add one.
139 return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
142 isl::schedule visitBand(isl::schedule_node_band Band, Args... args) {
143 isl::schedule NewChild =
144 getDerived().visit(Band.child(0), std::forward<Args>(args)...);
145 return rebuildBand(Band, NewChild, [](int) { return true; });
148 isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
149 Args... args) {
150 int NumChildren = isl_schedule_node_n_children(Sequence.get());
151 isl::schedule Result =
152 getDerived().visit(Sequence.child(0), std::forward<Args>(args)...);
153 for (int i = 1; i < NumChildren; i += 1)
154 Result = Result.sequence(
155 getDerived().visit(Sequence.child(i), std::forward<Args>(args)...));
156 return Result;
159 isl::schedule visitSet(isl::schedule_node_set Set, Args... args) {
160 int NumChildren = isl_schedule_node_n_children(Set.get());
161 isl::schedule Result =
162 getDerived().visit(Set.child(0), std::forward<Args>(args)...);
163 for (int i = 1; i < NumChildren; i += 1)
164 Result = isl::manage(
165 isl_schedule_set(Result.release(),
166 getDerived()
167 .visit(Set.child(i), std::forward<Args>(args)...)
168 .release()));
169 return Result;
172 isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
173 return isl::schedule::from_domain(Leaf.get_domain());
176 isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {
178 isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id();
179 isl::schedule_node NewChild =
180 getDerived()
181 .visit(Mark.first_child(), std::forward<Args>(args)...)
182 .get_root()
183 .first_child();
184 return NewChild.insert_mark(TheMark).get_schedule();
187 isl::schedule visitExtension(isl::schedule_node_extension Extension,
188 Args... args) {
189 isl::union_map TheExtension =
190 Extension.as<isl::schedule_node_extension>().get_extension();
191 isl::schedule_node NewChild = getDerived()
192 .visit(Extension.child(0), args...)
193 .get_root()
194 .first_child();
195 isl::schedule_node NewExtension =
196 isl::schedule_node::from_extension(TheExtension);
197 return NewChild.graft_before(NewExtension).get_schedule();
200 isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) {
201 isl::union_set FilterDomain =
202 Filter.as<isl::schedule_node_filter>().get_filter();
203 isl::schedule NewSchedule =
204 getDerived().visit(Filter.child(0), std::forward<Args>(args)...);
205 return NewSchedule.intersect_domain(FilterDomain);
208 isl::schedule visitNode(isl::schedule_node Node, Args... args) {
209 llvm_unreachable("Not implemented");
213 /// Rewrite the schedule tree without any changes. Useful to copy a subtree into
214 /// a new schedule, discarding everything but.
215 struct IdentityRewriter : ScheduleTreeRewriter<IdentityRewriter> {};
217 /// Rewrite a schedule tree to an equivalent one without extension nodes.
219 /// Each visit method takes two additional arguments:
221 /// * The new domain the node, which is the inherited domain plus any domains
222 /// added by extension nodes.
224 /// * A map of extension domains of all children is returned; it is required by
225 /// band nodes to schedule the additional domains at the same position as the
226 /// extension node would.
228 struct ExtensionNodeRewriter final
229 : ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
230 isl::union_map &> {
231 using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
232 const isl::union_set &, isl::union_map &>;
233 BaseTy &getBase() { return *this; }
234 const BaseTy &getBase() const { return *this; }
236 isl::schedule visitSchedule(isl::schedule Schedule) {
237 isl::union_map Extensions;
238 isl::schedule Result =
239 visit(Schedule.get_root(), Schedule.get_domain(), Extensions);
240 assert(!Extensions.is_null() && Extensions.is_empty());
241 return Result;
244 isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
245 const isl::union_set &Domain,
246 isl::union_map &Extensions) {
247 int NumChildren = isl_schedule_node_n_children(Sequence.get());
248 isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions);
249 for (int i = 1; i < NumChildren; i += 1) {
250 isl::schedule_node OldChild = Sequence.child(i);
251 isl::union_map NewChildExtensions;
252 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
253 NewNode = NewNode.sequence(NewChildNode);
254 Extensions = Extensions.unite(NewChildExtensions);
256 return NewNode;
259 isl::schedule visitSet(isl::schedule_node_set Set,
260 const isl::union_set &Domain,
261 isl::union_map &Extensions) {
262 int NumChildren = isl_schedule_node_n_children(Set.get());
263 isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions);
264 for (int i = 1; i < NumChildren; i += 1) {
265 isl::schedule_node OldChild = Set.child(i);
266 isl::union_map NewChildExtensions;
267 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
268 NewNode = isl::manage(
269 isl_schedule_set(NewNode.release(), NewChildNode.release()));
270 Extensions = Extensions.unite(NewChildExtensions);
272 return NewNode;
275 isl::schedule visitLeaf(isl::schedule_node_leaf Leaf,
276 const isl::union_set &Domain,
277 isl::union_map &Extensions) {
278 Extensions = isl::union_map::empty(Leaf.ctx());
279 return isl::schedule::from_domain(Domain);
282 isl::schedule visitBand(isl::schedule_node_band OldNode,
283 const isl::union_set &Domain,
284 isl::union_map &OuterExtensions) {
285 isl::schedule_node OldChild = OldNode.first_child();
286 isl::multi_union_pw_aff PartialSched =
287 isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get()));
289 isl::union_map NewChildExtensions;
290 isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions);
292 // Add the extensions to the partial schedule.
293 OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx());
294 isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched);
295 unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get());
296 for (isl::map Ext : NewChildExtensions.get_map_list()) {
297 unsigned ExtDims = unsignedFromIslSize(Ext.domain_tuple_dim());
298 assert(ExtDims >= BandDims);
299 unsigned OuterDims = ExtDims - BandDims;
301 isl::map BandSched =
302 Ext.project_out(isl::dim::in, 0, OuterDims).reverse();
303 NewPartialSchedMap = NewPartialSchedMap.unite(BandSched);
305 // There might be more outer bands that have to schedule the extensions.
306 if (OuterDims > 0) {
307 isl::map OuterSched =
308 Ext.project_out(isl::dim::in, OuterDims, BandDims);
309 OuterExtensions = OuterExtensions.unite(OuterSched);
312 isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff =
313 isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap);
314 isl::schedule_node NewNode =
315 NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff)
316 .get_root()
317 .child(0);
319 // Reapply permutability and coincidence attributes.
320 NewNode = isl::manage(isl_schedule_node_band_set_permutable(
321 NewNode.release(),
322 isl_schedule_node_band_get_permutable(OldNode.get())));
323 for (unsigned i = 0; i < BandDims; i += 1)
324 NewNode = applyBandMemberAttributes(NewNode.as<isl::schedule_node_band>(),
325 i, OldNode, i);
327 return NewNode.get_schedule();
330 isl::schedule visitFilter(isl::schedule_node_filter Filter,
331 const isl::union_set &Domain,
332 isl::union_map &Extensions) {
333 isl::union_set FilterDomain =
334 Filter.as<isl::schedule_node_filter>().get_filter();
335 isl::union_set NewDomain = Domain.intersect(FilterDomain);
337 // A filter is added implicitly if necessary when joining schedule trees.
338 return visit(Filter.first_child(), NewDomain, Extensions);
341 isl::schedule visitExtension(isl::schedule_node_extension Extension,
342 const isl::union_set &Domain,
343 isl::union_map &Extensions) {
344 isl::union_map ExtDomain =
345 Extension.as<isl::schedule_node_extension>().get_extension();
346 isl::union_set NewDomain = Domain.unite(ExtDomain.range());
347 isl::union_map ChildExtensions;
348 isl::schedule NewChild =
349 visit(Extension.first_child(), NewDomain, ChildExtensions);
350 Extensions = ChildExtensions.unite(ExtDomain);
351 return NewChild;
355 /// Collect all AST build options in any schedule tree band.
357 /// ScheduleTreeRewriter cannot apply the schedule tree options. This class
358 /// collects these options to apply them later.
359 struct CollectASTBuildOptions final
360 : RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
361 using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
362 BaseTy &getBase() { return *this; }
363 const BaseTy &getBase() const { return *this; }
365 llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;
367 void visitBand(isl::schedule_node_band Band) {
368 ASTBuildOptions.push_back(
369 isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get())));
370 return getBase().visitBand(Band);
374 /// Apply AST build options to the bands in a schedule tree.
376 /// This rewrites a schedule tree with the AST build options applied. We assume
377 /// that the band nodes are visited in the same order as they were when the
378 /// build options were collected, typically by CollectASTBuildOptions.
379 struct ApplyASTBuildOptions final : ScheduleNodeRewriter<ApplyASTBuildOptions> {
380 using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>;
381 BaseTy &getBase() { return *this; }
382 const BaseTy &getBase() const { return *this; }
384 size_t Pos;
385 llvm::ArrayRef<isl::union_set> ASTBuildOptions;
387 ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
388 : ASTBuildOptions(ASTBuildOptions) {}
390 isl::schedule visitSchedule(isl::schedule Schedule) {
391 Pos = 0;
392 isl::schedule Result = visit(Schedule).get_schedule();
393 assert(Pos == ASTBuildOptions.size() &&
394 "AST build options must match to band nodes");
395 return Result;
398 isl::schedule_node visitBand(isl::schedule_node_band Band) {
399 isl::schedule_node_band Result =
400 Band.set_ast_build_options(ASTBuildOptions[Pos]);
401 Pos += 1;
402 return getBase().visitBand(Result);
406 /// Return whether the schedule contains an extension node.
407 static bool containsExtensionNode(isl::schedule Schedule) {
408 assert(!Schedule.is_null());
410 auto Callback = [](__isl_keep isl_schedule_node *Node,
411 void *User) -> isl_bool {
412 if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) {
413 // Stop walking the schedule tree.
414 return isl_bool_error;
417 // Continue searching the subtree.
418 return isl_bool_true;
420 isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down(
421 Schedule.get(), Callback, nullptr);
423 // We assume that the traversal itself does not fail, i.e. the only reason to
424 // return isl_stat_error is that an extension node was found.
425 return RetVal == isl_stat_error;
428 /// Find a named MDNode property in a LoopID.
429 static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) {
430 return dyn_cast_or_null<MDNode>(
431 findMetadataOperand(LoopMD, Name).value_or(nullptr));
434 /// Is this node of type mark?
435 static bool isMark(const isl::schedule_node &Node) {
436 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark;
439 /// Is this node of type band?
440 static bool isBand(const isl::schedule_node &Node) {
441 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band;
444 #ifndef NDEBUG
445 /// Is this node a band of a single dimension (i.e. could represent a loop)?
446 static bool isBandWithSingleLoop(const isl::schedule_node &Node) {
447 return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1;
449 #endif
451 static bool isLeaf(const isl::schedule_node &Node) {
452 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf;
455 /// Create an isl::id representing the output loop after a transformation.
456 static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) {
457 // Don't need to id the followup.
458 // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by
459 // user followup-MD
460 if (!FollowupLoopMD)
461 return {};
463 BandAttr *Attr = new BandAttr();
464 Attr->Metadata = FollowupLoopMD;
465 return getIslLoopAttr(Ctx, Attr);
468 /// A loop consists of a band and an optional marker that wraps it. Return the
469 /// outermost of the two.
471 /// That is, either the mark or, if there is not mark, the loop itself. Can
472 /// start with either the mark or the band.
473 static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) {
474 if (isBandMark(BandOrMark)) {
475 assert(isBandWithSingleLoop(BandOrMark.child(0)));
476 return BandOrMark;
478 assert(isBandWithSingleLoop(BandOrMark));
480 isl::schedule_node Mark = BandOrMark.parent();
481 if (isBandMark(Mark))
482 return Mark;
484 // Band has no loop marker.
485 return BandOrMark;
488 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand,
489 BandAttr *&Attr) {
490 MarkOrBand = moveToBandMark(MarkOrBand);
492 isl::schedule_node Band;
493 if (isMark(MarkOrBand)) {
494 Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
495 Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release()));
496 } else {
497 Attr = nullptr;
498 Band = MarkOrBand;
501 assert(isBandWithSingleLoop(Band));
502 return Band;
505 /// Remove the mark that wraps a loop. Return the band representing the loop.
506 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) {
507 BandAttr *Attr;
508 return removeMark(MarkOrBand, Attr);
511 static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) {
512 assert(isBand(Band));
513 assert(moveToBandMark(Band).is_equal(Band) &&
514 "Don't add a two marks for a band");
516 return Band.insert_mark(Mark).child(0);
519 /// Return the (one-dimensional) set of numbers that are divisible by @p Factor
520 /// with remainder @p Offset.
522 /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 }
523 /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 }
525 static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor,
526 long Offset) {
527 isl::val ValFactor{Ctx, Factor};
528 isl::val ValOffset{Ctx, Offset};
530 isl::space Unispace{Ctx, 0, 1};
531 isl::local_space LUnispace{Unispace};
532 isl::aff AffFactor{LUnispace, ValFactor};
533 isl::aff AffOffset{LUnispace, ValOffset};
535 isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0);
536 isl::aff DivMul = Id.mod(ValFactor);
537 isl::basic_map Divisible = isl::basic_map::from_aff(DivMul);
538 isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset);
539 return Modulo.domain();
542 /// Make the last dimension of Set to take values from 0 to VectorWidth - 1.
544 /// @param Set A set, which should be modified.
545 /// @param VectorWidth A parameter, which determines the constraint.
546 static isl::set addExtentConstraints(isl::set Set, int VectorWidth) {
547 unsigned Dims = unsignedFromIslSize(Set.tuple_dim());
548 assert(Dims >= 1);
549 isl::space Space = Set.get_space();
550 isl::local_space LocalSpace = isl::local_space(Space);
551 isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
552 ExtConstr = ExtConstr.set_constant_si(0);
553 ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1);
554 Set = Set.add_constraint(ExtConstr);
555 ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
556 ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1);
557 ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1);
558 return Set.add_constraint(ExtConstr);
561 /// Collapse perfectly nested bands into a single band.
562 class BandCollapseRewriter final
563 : public ScheduleTreeRewriter<BandCollapseRewriter> {
564 private:
565 using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>;
566 BaseTy &getBase() { return *this; }
567 const BaseTy &getBase() const { return *this; }
569 public:
570 isl::schedule visitBand(isl::schedule_node_band RootBand) {
571 isl::schedule_node_band Band = RootBand;
572 isl::ctx Ctx = Band.ctx();
574 // Do not merge permutable band to avoid loosing the permutability property.
575 // Cannot collapse even two permutable loops, they might be permutable
576 // individually, but not necassarily across.
577 if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable())
578 return getBase().visitBand(Band);
580 // Find collapsable bands.
581 SmallVector<isl::schedule_node_band> Nest;
582 int NumTotalLoops = 0;
583 isl::schedule_node Body;
584 while (true) {
585 Nest.push_back(Band);
586 NumTotalLoops += unsignedFromIslSize(Band.n_member());
587 Body = Band.first_child();
588 if (!Body.isa<isl::schedule_node_band>())
589 break;
590 Band = Body.as<isl::schedule_node_band>();
592 // Do not include next band if it is permutable to not lose its
593 // permutability property.
594 if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable())
595 break;
598 // Nothing to collapse, preserve permutability.
599 if (Nest.size() <= 1)
600 return getBase().visitBand(Band);
602 LLVM_DEBUG({
603 dbgs() << "Found loops to collapse between\n";
604 dumpIslObj(RootBand, dbgs());
605 dbgs() << "and\n";
606 dumpIslObj(Body, dbgs());
607 dbgs() << "\n";
610 isl::schedule NewBody = visit(Body);
612 // Collect partial schedules from all members.
613 isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops};
614 for (isl::schedule_node_band Band : Nest) {
615 int NumLoops = unsignedFromIslSize(Band.n_member());
616 isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule();
617 for (auto j : seq<int>(0, NumLoops))
618 PartScheds = PartScheds.add(BandScheds.at(j));
620 isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops);
621 isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds};
623 isl::schedule_node_band CollapsedBand =
624 NewBody.insert_partial_schedule(PartSchedsMulti)
625 .get_root()
626 .first_child()
627 .as<isl::schedule_node_band>();
629 // Copy over loop attributes form original bands.
630 int LoopIdx = 0;
631 for (isl::schedule_node_band Band : Nest) {
632 int NumLoops = unsignedFromIslSize(Band.n_member());
633 for (int i : seq<int>(0, NumLoops)) {
634 CollapsedBand = applyBandMemberAttributes(std::move(CollapsedBand),
635 LoopIdx, Band, i);
636 LoopIdx += 1;
639 assert(LoopIdx == NumTotalLoops &&
640 "Expect the same number of loops to add up again");
642 return CollapsedBand.get_schedule();
646 static isl::schedule collapseBands(isl::schedule Sched) {
647 LLVM_DEBUG(dbgs() << "Collapse bands in schedule\n");
648 BandCollapseRewriter Rewriter;
649 return Rewriter.visit(Sched);
652 /// Collect sequentially executed bands (or anything else), even if nested in a
653 /// mark or other nodes whose child is executed just once. If we can
654 /// successfully fuse the bands, we allow them to be removed.
655 static void collectPotentiallyFusableBands(
656 isl::schedule_node Node,
657 SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>>
658 &ScheduleBands,
659 const isl::schedule_node &DirectChild) {
660 switch (isl_schedule_node_get_type(Node.get())) {
661 case isl_schedule_node_sequence:
662 case isl_schedule_node_set:
663 case isl_schedule_node_mark:
664 case isl_schedule_node_domain:
665 case isl_schedule_node_filter:
666 if (Node.has_children()) {
667 isl::schedule_node C = Node.first_child();
668 while (true) {
669 collectPotentiallyFusableBands(C, ScheduleBands, DirectChild);
670 if (!C.has_next_sibling())
671 break;
672 C = C.next_sibling();
675 break;
677 default:
678 // Something that does not execute suquentially (e.g. a band)
679 ScheduleBands.push_back({Node, DirectChild});
680 break;
684 /// Remove dependencies that are resolved by @p PartSched. That is, remove
685 /// everything that we already know is executed in-order.
686 static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched,
687 isl::union_map Deps) {
688 unsigned NumDims = getNumScatterDims(PartSched);
689 auto ParamSpace = PartSched.get_space().params();
691 // { Scatter[] }
692 isl::space ScatterSpace =
693 ParamSpace.set_from_params().add_dims(isl::dim::set, NumDims);
695 // { Scatter[] -> Domain[] }
696 isl::union_map PartSchedRev = PartSched.reverse();
698 // { Scatter[] -> Scatter[] }
699 isl::map MaybeBefore = isl::map::lex_le(ScatterSpace);
701 // { Domain[] -> Domain[] }
702 isl::union_map DomMaybeBefore =
703 MaybeBefore.apply_domain(PartSchedRev).apply_range(PartSchedRev);
705 // { Domain[] -> Domain[] }
706 isl::union_map ChildRemainingDeps = Deps.intersect(DomMaybeBefore);
708 return ChildRemainingDeps;
711 /// Remove dependencies that are resolved by executing them in the order
712 /// specified by @p Domains;
713 static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains,
714 isl::union_map Deps) {
715 isl::ctx Ctx = Deps.ctx();
716 isl::space ParamSpace = Deps.get_space().params();
718 // Create a partial schedule mapping to constants that reflect the execution
719 // order.
720 isl::union_map PartialSchedules = isl::union_map::empty(Ctx);
721 for (auto P : enumerate(Domains)) {
722 isl::val ExecTime = isl::val(Ctx, P.index());
723 isl::union_pw_aff DomSched{P.value(), ExecTime};
724 PartialSchedules = PartialSchedules.unite(DomSched.as_union_map());
727 return remainingDepsFromPartialSchedule(PartialSchedules, Deps);
730 /// Determine whether the outermost loop of to bands can be fused while
731 /// respecting validity dependencies.
732 static bool canFuseOutermost(const isl::schedule_node_band &LHS,
733 const isl::schedule_node_band &RHS,
734 const isl::union_map &Deps) {
735 // { LHSDomain[] -> Scatter[] }
736 isl::union_map LHSPartSched =
737 LHS.get_partial_schedule().get_at(0).as_union_map();
739 // { Domain[] -> Scatter[] }
740 isl::union_map RHSPartSched =
741 RHS.get_partial_schedule().get_at(0).as_union_map();
743 // Dependencies that are already resolved because LHS executes before RHS, but
744 // will not be anymore after fusion. { DefDomain[] -> UseDomain[] }
745 isl::union_map OrderedBySequence =
746 Deps.intersect_domain(LHSPartSched.domain())
747 .intersect_range(RHSPartSched.domain());
749 isl::space ParamSpace = OrderedBySequence.get_space().params();
750 isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(1);
752 // { Scatter[] -> Scatter[] }
753 isl::map After = isl::map::lex_gt(NewScatterSpace);
755 // After fusion, instances with smaller (or equal, which means they will be
756 // executed in the same iteration, but the LHS instance is still sequenced
757 // before RHS) scatter value will still be executed before. This are the
758 // orderings where this is not necessarily the case.
759 // { LHSDomain[] -> RHSDomain[] }
760 isl::union_map MightBeAfterDoms = After.apply_domain(LHSPartSched.reverse())
761 .apply_range(RHSPartSched.reverse());
763 // Dependencies that are not resolved by the new execution order.
764 isl::union_map WithBefore = OrderedBySequence.intersect(MightBeAfterDoms);
766 return WithBefore.is_empty();
769 /// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies.
770 static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS,
771 isl::schedule_node_band RHS,
772 const isl::union_map &Deps) {
773 if (!canFuseOutermost(LHS, RHS, Deps))
774 return {};
776 LLVM_DEBUG({
777 dbgs() << "Found loops for greedy fusion:\n";
778 dumpIslObj(LHS, dbgs());
779 dbgs() << "and\n";
780 dumpIslObj(RHS, dbgs());
781 dbgs() << "\n";
784 // The partial schedule of the bands outermost loop that we need to combine
785 // for the fusion.
786 isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(0);
787 isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(0);
789 // Isolate band bodies as roots of their own schedule trees.
790 IdentityRewriter Rewriter;
791 isl::schedule LHSBody = Rewriter.visit(LHS.first_child());
792 isl::schedule RHSBody = Rewriter.visit(RHS.first_child());
794 // Reconstruct the non-outermost (not going to be fused) loops from both
795 // bands.
796 // TODO: Maybe it is possibly to transfer the 'permutability' property from
797 // LHS+RHS. At minimum we need merge multiple band members at once, otherwise
798 // permutability has no meaning.
799 isl::schedule LHSNewBody =
800 rebuildBand(LHS, LHSBody, [](int i) { return i > 0; });
801 isl::schedule RHSNewBody =
802 rebuildBand(RHS, RHSBody, [](int i) { return i > 0; });
804 // The loop body of the fused loop.
805 isl::schedule NewCommonBody = LHSNewBody.sequence(RHSNewBody);
807 // Combine the partial schedules of both loops to a new one. Instances with
808 // the same scatter value are put together.
809 isl::union_map NewCommonPartialSched =
810 LHSPartOuterSched.as_union_map().unite(RHSPartOuterSched.as_union_map());
811 isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule(
812 NewCommonPartialSched.as_multi_union_pw_aff());
814 return NewCommonSchedule;
817 static isl::schedule tryGreedyFuse(isl::schedule_node LHS,
818 isl::schedule_node RHS,
819 const isl::union_map &Deps) {
820 // TODO: Non-bands could be interpreted as a band with just as single
821 // iteration. However, this is only useful if both ends of a fused loop were
822 // originally loops themselves.
823 if (!LHS.isa<isl::schedule_node_band>())
824 return {};
825 if (!RHS.isa<isl::schedule_node_band>())
826 return {};
828 return tryGreedyFuse(LHS.as<isl::schedule_node_band>(),
829 RHS.as<isl::schedule_node_band>(), Deps);
832 /// Fuse all fusable loop top-down in a schedule tree.
834 /// The isl::union_map parameters is the set of validity dependencies that have
835 /// not been resolved/carried by a parent schedule node.
836 class GreedyFusionRewriter final
837 : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> {
838 private:
839 using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>;
840 BaseTy &getBase() { return *this; }
841 const BaseTy &getBase() const { return *this; }
843 public:
844 /// Is set to true if anything has been fused.
845 bool AnyChange = false;
847 isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) {
848 // { Domain[] -> Scatter[] }
849 isl::union_map PartSched =
850 isl::union_map::from(Band.get_partial_schedule());
851 assert(getNumScatterDims(PartSched) ==
852 unsignedFromIslSize(Band.n_member()));
853 isl::space ParamSpace = PartSched.get_space().params();
855 // { Scatter[] -> Domain[] }
856 isl::union_map PartSchedRev = PartSched.reverse();
858 // Possible within the same iteration. Dependencies with smaller scatter
859 // value are carried by this loop and therefore have been resolved by the
860 // in-order execution if the loop iteration. A dependency with small scatter
861 // value would be a dependency violation that we assume did not happen. {
862 // Domain[] -> Domain[] }
863 isl::union_map Unsequenced = PartSchedRev.apply_domain(PartSchedRev);
865 // Actual dependencies within the same iteration.
866 // { DefDomain[] -> UseDomain[] }
867 isl::union_map RemDeps = Deps.intersect(Unsequenced);
869 return getBase().visitBand(Band, RemDeps);
872 isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
873 isl::union_map Deps) {
874 int NumChildren = isl_schedule_node_n_children(Sequence.get());
876 // List of fusion candidates. The first element is the fusion candidate, the
877 // second is candidate's ancestor that is the sequence's direct child. It is
878 // preferable to use the direct child if not if its non-direct children is
879 // fused to preserve its structure such as mark nodes.
880 SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands;
881 for (auto i : seq<int>(0, NumChildren)) {
882 isl::schedule_node Child = Sequence.child(i);
883 collectPotentiallyFusableBands(Child, Bands, Child);
886 // Direct children that had at least one of its decendants fused.
887 SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren;
889 // Fuse neigboring bands until reaching the end of candidates.
890 int i = 0;
891 while (i + 1 < (int)Bands.size()) {
892 isl::schedule Fused =
893 tryGreedyFuse(Bands[i].first, Bands[i + 1].first, Deps);
894 if (Fused.is_null()) {
895 // Cannot merge this node with the next; look at next pair.
896 i += 1;
897 continue;
900 // Mark the direct children as (partially) fused.
901 if (!Bands[i].second.is_null())
902 ChangedDirectChildren.insert(Bands[i].second.get());
903 if (!Bands[i + 1].second.is_null())
904 ChangedDirectChildren.insert(Bands[i + 1].second.get());
906 // Collapse the neigbros to a single new candidate that could be fused
907 // with the next candidate.
908 Bands[i] = {Fused.get_root(), {}};
909 Bands.erase(Bands.begin() + i + 1);
911 AnyChange = true;
914 // By construction equal if done with collectPotentiallyFusableBands's
915 // output.
916 SmallVector<isl::union_set> SubDomains;
917 SubDomains.reserve(NumChildren);
918 for (int i = 0; i < NumChildren; i += 1)
919 SubDomains.push_back(Sequence.child(i).domain());
920 auto SubRemainingDeps = remainigDepsFromSequence(SubDomains, Deps);
922 // We may iterate over direct children multiple times, be sure to add each
923 // at most once.
924 SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded;
926 isl::schedule Result;
927 for (auto &P : Bands) {
928 isl::schedule_node MaybeFused = P.first;
929 isl::schedule_node DirectChild = P.second;
931 // If not modified, use the direct child.
932 if (!DirectChild.is_null() &&
933 !ChangedDirectChildren.count(DirectChild.get())) {
934 if (AlreadyAdded.count(DirectChild.get()))
935 continue;
936 AlreadyAdded.insert(DirectChild.get());
937 MaybeFused = DirectChild;
938 } else {
939 assert(AnyChange &&
940 "Need changed flag for be consistent with actual change");
943 // Top-down recursion: If the outermost loop has been fused, their nested
944 // bands might be fusable now as well.
945 isl::schedule InnerFused = visit(MaybeFused, SubRemainingDeps);
947 // Reconstruct the sequence, with some of the children fused.
948 if (Result.is_null())
949 Result = InnerFused;
950 else
951 Result = Result.sequence(InnerFused);
954 return Result;
958 } // namespace
960 bool polly::isBandMark(const isl::schedule_node &Node) {
961 return isMark(Node) &&
962 isLoopAttr(Node.as<isl::schedule_node_mark>().get_id());
965 BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) {
966 MarkOrBand = moveToBandMark(MarkOrBand);
967 if (!isMark(MarkOrBand))
968 return nullptr;
970 return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
973 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) {
974 // If there is no extension node in the first place, return the original
975 // schedule tree.
976 if (!containsExtensionNode(Sched))
977 return Sched;
979 // Build options can anchor schedule nodes, such that the schedule tree cannot
980 // be modified anymore. Therefore, apply build options after the tree has been
981 // created.
982 CollectASTBuildOptions Collector;
983 Collector.visit(Sched);
985 // Rewrite the schedule tree without extension nodes.
986 ExtensionNodeRewriter Rewriter;
987 isl::schedule NewSched = Rewriter.visitSchedule(Sched);
989 // Reapply the AST build options. The rewriter must not change the iteration
990 // order of bands. Any other node type is ignored.
991 ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
992 NewSched = Applicator.visitSchedule(NewSched);
994 return NewSched;
997 isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
998 isl::ctx Ctx = BandToUnroll.ctx();
1000 // Remove the loop's mark, the loop will disappear anyway.
1001 BandToUnroll = removeMark(BandToUnroll);
1002 assert(isBandWithSingleLoop(BandToUnroll));
1004 isl::multi_union_pw_aff PartialSched = isl::manage(
1005 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
1006 assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u &&
1007 "Can only unroll a single dimension");
1008 isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);
1010 isl::union_set Domain = BandToUnroll.get_domain();
1011 PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain);
1012 isl::union_map PartialSchedUMap =
1013 isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));
1015 // Enumerator only the scatter elements.
1016 isl::union_set ScatterList = PartialSchedUMap.range();
1018 // Enumerate all loop iterations.
1019 // TODO: Diagnose if not enumerable or depends on a parameter.
1020 SmallVector<isl::point, 16> Elts;
1021 ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat {
1022 Elts.push_back(P);
1023 return isl::stat::ok();
1026 // Don't assume that foreach_point returns in execution order.
1027 llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool {
1028 isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0);
1029 isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0);
1030 return C1.lt(C2);
1033 // Convert the points to a sequence of filters.
1034 isl::union_set_list List = isl::union_set_list(Ctx, Elts.size());
1035 for (isl::point P : Elts) {
1036 // Determine the domains that map this scatter element.
1037 isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain();
1039 List = List.add(DomainFilter);
1042 // Replace original band with unrolled sequence.
1043 isl::schedule_node Body =
1044 isl::manage(isl_schedule_node_delete(BandToUnroll.release()));
1045 Body = Body.insert_sequence(List);
1046 return Body.get_schedule();
1049 isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll,
1050 int Factor) {
1051 assert(Factor > 0 && "Positive unroll factor required");
1052 isl::ctx Ctx = BandToUnroll.ctx();
1054 // Remove the mark, save the attribute for later use.
1055 BandAttr *Attr;
1056 BandToUnroll = removeMark(BandToUnroll, Attr);
1057 assert(isBandWithSingleLoop(BandToUnroll));
1059 isl::multi_union_pw_aff PartialSched = isl::manage(
1060 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
1062 // { Stmt[] -> [x] }
1063 isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);
1065 // Here we assume the schedule stride is one and starts with 0, which is not
1066 // necessarily the case.
1067 isl::union_pw_aff StridedPartialSchedUAff =
1068 isl::union_pw_aff::empty(PartialSchedUAff.get_space());
1069 isl::val ValFactor{Ctx, Factor};
1070 PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff,
1071 &ValFactor](isl::pw_aff PwAff) -> isl::stat {
1072 isl::space Space = PwAff.get_space();
1073 isl::set Universe = isl::set::universe(Space.domain());
1074 isl::pw_aff AffFactor{Universe, ValFactor};
1075 isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor);
1076 StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff);
1077 return isl::stat::ok();
1080 isl::union_set_list List = isl::union_set_list(Ctx, Factor);
1081 for (auto i : seq<int>(0, Factor)) {
1082 // { Stmt[] -> [x] }
1083 isl::union_map UMap =
1084 isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));
1086 // { [x] }
1087 isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i);
1089 // { Stmt[] }
1090 isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain();
1092 List = List.add(UnrolledDomain);
1095 isl::schedule_node Body =
1096 isl::manage(isl_schedule_node_delete(BandToUnroll.copy()));
1097 Body = Body.insert_sequence(List);
1098 isl::schedule_node NewLoop =
1099 Body.insert_partial_schedule(StridedPartialSchedUAff);
1101 MDNode *FollowupMD = nullptr;
1102 if (Attr && Attr->Metadata)
1103 FollowupMD =
1104 findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled);
1106 isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD);
1107 if (!NewBandId.is_null())
1108 NewLoop = insertMark(NewLoop, NewBandId);
1110 return NewLoop.get_schedule();
1113 isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange,
1114 int VectorWidth) {
1115 unsigned Dims = unsignedFromIslSize(ScheduleRange.tuple_dim());
1116 assert(Dims >= 1);
1117 isl::set LoopPrefixes =
1118 ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1);
1119 auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth);
1120 isl::set BadPrefixes = ExtentPrefixes.subtract(ScheduleRange);
1121 BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1);
1122 LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1);
1123 return LoopPrefixes.subtract(BadPrefixes);
1126 isl::union_set polly::getIsolateOptions(isl::set IsolateDomain,
1127 unsigned OutDimsNum) {
1128 unsigned Dims = unsignedFromIslSize(IsolateDomain.tuple_dim());
1129 assert(OutDimsNum <= Dims &&
1130 "The isl::set IsolateDomain is used to describe the range of schedule "
1131 "dimensions values, which should be isolated. Consequently, the "
1132 "number of its dimensions should be greater than or equal to the "
1133 "number of the schedule dimensions.");
1134 isl::map IsolateRelation = isl::map::from_domain(IsolateDomain);
1135 IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in,
1136 Dims - OutDimsNum, OutDimsNum);
1137 isl::set IsolateOption = IsolateRelation.wrap();
1138 isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr);
1139 IsolateOption = IsolateOption.set_tuple_id(Id);
1140 return isl::union_set(IsolateOption);
1143 isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) {
1144 isl::space Space(Ctx, 0, 1);
1145 auto DimOption = isl::set::universe(Space);
1146 auto Id = isl::id::alloc(Ctx, Option, nullptr);
1147 DimOption = DimOption.set_tuple_id(Id);
1148 return isl::union_set(DimOption);
1151 isl::schedule_node polly::tileNode(isl::schedule_node Node,
1152 const char *Identifier,
1153 ArrayRef<int> TileSizes,
1154 int DefaultTileSize) {
1155 auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
1156 auto Dims = Space.dim(isl::dim::set);
1157 auto Sizes = isl::multi_val::zero(Space);
1158 std::string IdentifierString(Identifier);
1159 for (unsigned i : rangeIslSize(0, Dims)) {
1160 unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize;
1161 Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize));
1163 auto TileLoopMarkerStr = IdentifierString + " - Tiles";
1164 auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr);
1165 Node = Node.insert_mark(TileLoopMarker);
1166 Node = Node.child(0);
1167 Node =
1168 isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release()));
1169 Node = Node.child(0);
1170 auto PointLoopMarkerStr = IdentifierString + " - Points";
1171 auto PointLoopMarker =
1172 isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr);
1173 Node = Node.insert_mark(PointLoopMarker);
1174 return Node.child(0);
1177 isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node,
1178 ArrayRef<int> TileSizes,
1179 int DefaultTileSize) {
1180 Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize);
1181 auto Ctx = Node.ctx();
1182 return Node.as<isl::schedule_node_band>().set_ast_build_options(
1183 isl::union_set(Ctx, "{unroll[x]}"));
1186 /// Find statements and sub-loops in (possibly nested) sequences.
1187 static void
1188 collectFissionableStmts(isl::schedule_node Node,
1189 SmallVectorImpl<isl::schedule_node> &ScheduleStmts) {
1190 if (isBand(Node) || isLeaf(Node)) {
1191 ScheduleStmts.push_back(Node);
1192 return;
1195 if (Node.has_children()) {
1196 isl::schedule_node C = Node.first_child();
1197 while (true) {
1198 collectFissionableStmts(C, ScheduleStmts);
1199 if (!C.has_next_sibling())
1200 break;
1201 C = C.next_sibling();
1206 isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) {
1207 isl::ctx Ctx = BandToFission.ctx();
1208 BandToFission = removeMark(BandToFission);
1209 isl::schedule_node BandBody = BandToFission.child(0);
1211 SmallVector<isl::schedule_node> FissionableStmts;
1212 collectFissionableStmts(BandBody, FissionableStmts);
1213 size_t N = FissionableStmts.size();
1215 // Collect the domain for each of the statements that will get their own loop.
1216 isl::union_set_list DomList = isl::union_set_list(Ctx, N);
1217 for (size_t i = 0; i < N; ++i) {
1218 isl::schedule_node BodyPart = FissionableStmts[i];
1219 DomList = DomList.add(BodyPart.get_domain());
1222 // Apply the fission by copying the entire loop, but inserting a filter for
1223 // the statement domains for each fissioned loop.
1224 isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList);
1226 return Fissioned.get_schedule();
1229 isl::schedule polly::applyGreedyFusion(isl::schedule Sched,
1230 const isl::union_map &Deps) {
1231 LLVM_DEBUG(dbgs() << "Greedy loop fusion\n");
1233 GreedyFusionRewriter Rewriter;
1234 isl::schedule Result = Rewriter.visit(Sched, Deps);
1235 if (!Rewriter.AnyChange) {
1236 LLVM_DEBUG(dbgs() << "Found nothing to fuse\n");
1237 return Sched;
1240 // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops
1241 // may have been split into multiple bands.
1242 return collapseBands(Result);