1 //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Utilities for generating tiled loops for matrix operations.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/Transforms/Utils/MatrixUtils.h"
14 #include "llvm/Analysis/DomTreeUpdater.h"
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/IR/BasicBlock.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/Type.h"
23 BasicBlock
*TileInfo::CreateLoop(BasicBlock
*Preheader
, BasicBlock
*Exit
,
24 Value
*Bound
, Value
*Step
, StringRef Name
,
25 IRBuilderBase
&B
, DomTreeUpdater
&DTU
, Loop
*L
,
27 LLVMContext
&Ctx
= Preheader
->getContext();
28 BasicBlock
*Header
= BasicBlock::Create(
29 Preheader
->getContext(), Name
+ ".header", Preheader
->getParent(), Exit
);
30 BasicBlock
*Body
= BasicBlock::Create(Header
->getContext(), Name
+ ".body",
31 Header
->getParent(), Exit
);
32 BasicBlock
*Latch
= BasicBlock::Create(Header
->getContext(), Name
+ ".latch",
33 Header
->getParent(), Exit
);
35 Type
*I32Ty
= Type::getInt64Ty(Ctx
);
36 BranchInst::Create(Body
, Header
);
37 BranchInst::Create(Latch
, Body
);
39 PHINode::Create(I32Ty
, 2, Name
+ ".iv", Header
->getTerminator());
40 IV
->addIncoming(ConstantInt::get(I32Ty
, 0), Preheader
);
42 B
.SetInsertPoint(Latch
);
43 Value
*Inc
= B
.CreateAdd(IV
, Step
, Name
+ ".step");
44 Value
*Cond
= B
.CreateICmpNE(Inc
, Bound
, Name
+ ".cond");
45 BranchInst::Create(Header
, Exit
, Cond
, Latch
);
46 IV
->addIncoming(Inc
, Latch
);
48 BranchInst
*PreheaderBr
= cast
<BranchInst
>(Preheader
->getTerminator());
49 BasicBlock
*Tmp
= PreheaderBr
->getSuccessor(0);
50 PreheaderBr
->setSuccessor(0, Header
);
51 DTU
.applyUpdatesPermissive({
52 {DominatorTree::Delete
, Preheader
, Tmp
},
53 {DominatorTree::Insert
, Header
, Body
},
54 {DominatorTree::Insert
, Body
, Latch
},
55 {DominatorTree::Insert
, Latch
, Header
},
56 {DominatorTree::Insert
, Latch
, Exit
},
57 {DominatorTree::Insert
, Preheader
, Header
},
60 L
->addBasicBlockToLoop(Header
, LI
);
61 L
->addBasicBlockToLoop(Body
, LI
);
62 L
->addBasicBlockToLoop(Latch
, LI
);
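// Creates the following loop nest skeleton:
//  for C = 0; C < NumColumns; C += TileSize
//    for R = 0; R < NumRows; R += TileSize
//      for K = 0; K < NumInner; K += TileSize
// Each loop is built by CreateLoop and is bottom-tested: its body runs at
// least once and it exits when the incremented IV equals the bound (`!=`
// test), so each bound must be a positive multiple of TileSize.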
70 BasicBlock
*TileInfo::CreateTiledLoops(BasicBlock
*Start
, BasicBlock
*End
,
71 IRBuilderBase
&B
, DomTreeUpdater
&DTU
,
73 Loop
*ColLoop
= LI
.AllocateLoop();
74 Loop
*RowLoop
= LI
.AllocateLoop();
75 Loop
*InnerLoop
= LI
.AllocateLoop();
76 RowLoop
->addChildLoop(InnerLoop
);
77 ColLoop
->addChildLoop(RowLoop
);
78 if (Loop
*ParentL
= LI
.getLoopFor(Start
))
79 ParentL
->addChildLoop(ColLoop
);
81 LI
.addTopLevelLoop(ColLoop
);
84 CreateLoop(Start
, End
, B
.getInt64(NumColumns
), B
.getInt64(TileSize
),
85 "cols", B
, DTU
, ColLoop
, LI
);
86 BasicBlock
*ColLatch
= ColBody
->getSingleSuccessor();
88 CreateLoop(ColBody
, ColLatch
, B
.getInt64(NumRows
), B
.getInt64(TileSize
),
89 "rows", B
, DTU
, RowLoop
, LI
);
90 RowLoopLatch
= RowBody
->getSingleSuccessor();
92 BasicBlock
*InnerBody
=
93 CreateLoop(RowBody
, RowLoopLatch
, B
.getInt64(NumInner
),
94 B
.getInt64(TileSize
), "inner", B
, DTU
, InnerLoop
, LI
);
95 InnerLoopLatch
= InnerBody
->getSingleSuccessor();
96 ColumnLoopHeader
= ColBody
->getSinglePredecessor();
97 RowLoopHeader
= RowBody
->getSinglePredecessor();
98 InnerLoopHeader
= InnerBody
->getSinglePredecessor();
99 CurrentRow
= &*RowLoopHeader
->begin();
100 CurrentCol
= &*ColumnLoopHeader
->begin();
101 CurrentK
= &*InnerLoopHeader
->begin();