[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / polly / lib / CodeGen / LoopGeneratorsKMP.cpp
blobb3af7b14f478082ac58be18a9144f20ee19ef0ba
1 //===------ LoopGeneratorsKMP.cpp - IR helper to create loops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains functions to create parallel loops as LLVM-IR.
11 //===----------------------------------------------------------------------===//
13 #include "polly/CodeGen/LoopGeneratorsKMP.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/Module.h"
17 using namespace llvm;
18 using namespace polly;
20 void ParallelLoopGeneratorKMP::createCallSpawnThreads(Value *SubFn,
21 Value *SubFnParam,
22 Value *LB, Value *UB,
23 Value *Stride) {
24 const std::string Name = "__kmpc_fork_call";
25 Function *F = M->getFunction(Name);
26 Type *KMPCMicroTy = StructType::getTypeByName(M->getContext(), "kmpc_micro");
28 if (!KMPCMicroTy) {
29 // void (*kmpc_micro)(kmp_int32 *global_tid, kmp_int32 *bound_tid, ...)
30 Type *MicroParams[] = {Builder.getInt32Ty()->getPointerTo(),
31 Builder.getInt32Ty()->getPointerTo()};
33 KMPCMicroTy = FunctionType::get(Builder.getVoidTy(), MicroParams, true);
36 // If F is not available, declare it.
37 if (!F) {
38 StructType *IdentTy =
39 StructType::getTypeByName(M->getContext(), "struct.ident_t");
41 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
42 Type *Params[] = {IdentTy->getPointerTo(), Builder.getInt32Ty(),
43 KMPCMicroTy->getPointerTo()};
45 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), Params, true);
46 F = Function::Create(Ty, Linkage, Name, M);
49 Value *Task = Builder.CreatePointerBitCastOrAddrSpaceCast(
50 SubFn, KMPCMicroTy->getPointerTo());
52 Value *Args[] = {SourceLocationInfo,
53 Builder.getInt32(4) /* Number of arguments (w/o Task) */,
54 Task,
55 LB,
56 UB,
57 Stride,
58 SubFnParam};
60 CallInst *Call = Builder.CreateCall(F, Args);
61 Call->setDebugLoc(DLGenerated);
64 void ParallelLoopGeneratorKMP::deployParallelExecution(Function *SubFn,
65 Value *SubFnParam,
66 Value *LB, Value *UB,
67 Value *Stride) {
68 // Inform OpenMP runtime about the number of threads if greater than zero
69 if (PollyNumThreads > 0) {
70 Value *GlobalThreadID = createCallGlobalThreadNum();
71 createCallPushNumThreads(GlobalThreadID, Builder.getInt32(PollyNumThreads));
74 // Tell the runtime we start a parallel loop
75 createCallSpawnThreads(SubFn, SubFnParam, LB, UB, Stride);
78 Function *ParallelLoopGeneratorKMP::prepareSubFnDefinition(Function *F) const {
79 std::vector<Type *> Arguments = {Builder.getInt32Ty()->getPointerTo(),
80 Builder.getInt32Ty()->getPointerTo(),
81 LongType,
82 LongType,
83 LongType,
84 Builder.getPtrTy()};
86 FunctionType *FT = FunctionType::get(Builder.getVoidTy(), Arguments, false);
87 Function *SubFn = Function::Create(FT, Function::InternalLinkage,
88 F->getName() + "_polly_subfn", M);
89 // Name the function's arguments
90 Function::arg_iterator AI = SubFn->arg_begin();
91 AI->setName("polly.kmpc.global_tid");
92 std::advance(AI, 1);
93 AI->setName("polly.kmpc.bound_tid");
94 std::advance(AI, 1);
95 AI->setName("polly.kmpc.lb");
96 std::advance(AI, 1);
97 AI->setName("polly.kmpc.ub");
98 std::advance(AI, 1);
99 AI->setName("polly.kmpc.inc");
100 std::advance(AI, 1);
101 AI->setName("polly.kmpc.shared");
103 return SubFn;
106 // Create a subfunction of the following (preliminary) structure:
108 // PrevBB
109 // |
110 // v
111 // HeaderBB
112 // / | _____
113 // / v v |
114 // / PreHeaderBB |
115 // | | |
116 // | v |
117 // | CheckNextBB |
118 // \ | \_____/
119 // \ |
120 // v v
121 // ExitBB
123 // HeaderBB will hold allocations, loading of variables and kmp-init calls.
124 // CheckNextBB will check for more work (dynamic / static chunked) or will be
125 // empty (static non chunked).
126 // If there is more work to do: go to PreHeaderBB, otherwise go to ExitBB.
127 // PreHeaderBB loads the new boundaries (& will lead to the loop body later on).
128 // Just like CheckNextBB: PreHeaderBB is (preliminary) empty in the static non
129 // chunked scheduling case. ExitBB marks the end of the parallel execution.
130 // The possibly empty BasicBlocks will automatically be removed.
131 std::tuple<Value *, Function *>
132 ParallelLoopGeneratorKMP::createSubFn(Value *SequentialLoopStride,
133 AllocaInst *StructData,
134 SetVector<Value *> Data, ValueMapT &Map) {
135 Function *SubFn = createSubFnDefinition();
136 LLVMContext &Context = SubFn->getContext();
138 // Store the previous basic block.
139 BasicBlock *PrevBB = Builder.GetInsertBlock();
141 // Create basic blocks.
142 BasicBlock *HeaderBB = BasicBlock::Create(Context, "polly.par.setup", SubFn);
143 BasicBlock *ExitBB = BasicBlock::Create(Context, "polly.par.exit", SubFn);
144 BasicBlock *CheckNextBB =
145 BasicBlock::Create(Context, "polly.par.checkNext", SubFn);
146 BasicBlock *PreHeaderBB =
147 BasicBlock::Create(Context, "polly.par.loadIVBounds", SubFn);
149 DT.addNewBlock(HeaderBB, PrevBB);
150 DT.addNewBlock(ExitBB, HeaderBB);
151 DT.addNewBlock(CheckNextBB, HeaderBB);
152 DT.addNewBlock(PreHeaderBB, HeaderBB);
154 // Fill up basic block HeaderBB.
155 Builder.SetInsertPoint(HeaderBB);
156 Value *LBPtr = Builder.CreateAlloca(LongType, nullptr, "polly.par.LBPtr");
157 Value *UBPtr = Builder.CreateAlloca(LongType, nullptr, "polly.par.UBPtr");
158 Value *IsLastPtr = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
159 "polly.par.lastIterPtr");
160 Value *StridePtr =
161 Builder.CreateAlloca(LongType, nullptr, "polly.par.StridePtr");
163 // Get iterator for retrieving the previously defined parameters.
164 Function::arg_iterator AI = SubFn->arg_begin();
165 // First argument holds "global thread ID".
166 Value *IDPtr = &*AI;
167 // Skip "bound thread ID" since it is not used (but had to be defined).
168 std::advance(AI, 2);
169 // Move iterator to: LB, UB, Stride, Shared variable struct.
170 Value *LB = &*AI;
171 std::advance(AI, 1);
172 Value *UB = &*AI;
173 std::advance(AI, 1);
174 Value *Stride = &*AI;
175 std::advance(AI, 1);
176 Value *Shared = &*AI;
178 extractValuesFromStruct(Data, StructData->getAllocatedType(), Shared, Map);
180 const auto Alignment = llvm::Align(is64BitArch() ? 8 : 4);
181 Value *ID = Builder.CreateAlignedLoad(Builder.getInt32Ty(), IDPtr, Alignment,
182 "polly.par.global_tid");
184 Builder.CreateAlignedStore(LB, LBPtr, Alignment);
185 Builder.CreateAlignedStore(UB, UBPtr, Alignment);
186 Builder.CreateAlignedStore(Builder.getInt32(0), IsLastPtr, Alignment);
187 Builder.CreateAlignedStore(Stride, StridePtr, Alignment);
189 // Subtract one as the upper bound provided by openmp is a < comparison
190 // whereas the codegenForSequential function creates a <= comparison.
191 Value *AdjustedUB = Builder.CreateAdd(UB, ConstantInt::get(LongType, -1),
192 "polly.indvar.UBAdjusted");
194 Value *ChunkSize =
195 ConstantInt::get(LongType, std::max<int>(PollyChunkSize, 1));
197 OMPGeneralSchedulingType Scheduling =
198 getSchedType(PollyChunkSize, PollyScheduling);
200 switch (Scheduling) {
201 case OMPGeneralSchedulingType::Dynamic:
202 case OMPGeneralSchedulingType::Guided:
203 case OMPGeneralSchedulingType::Runtime:
204 // "DYNAMIC" scheduling types are handled below (including 'runtime')
206 UB = AdjustedUB;
207 createCallDispatchInit(ID, LB, UB, Stride, ChunkSize);
208 Value *HasWork =
209 createCallDispatchNext(ID, IsLastPtr, LBPtr, UBPtr, StridePtr);
210 Value *HasIteration =
211 Builder.CreateICmp(llvm::CmpInst::Predicate::ICMP_EQ, HasWork,
212 Builder.getInt32(1), "polly.hasIteration");
213 Builder.CreateCondBr(HasIteration, PreHeaderBB, ExitBB);
215 Builder.SetInsertPoint(CheckNextBB);
216 HasWork = createCallDispatchNext(ID, IsLastPtr, LBPtr, UBPtr, StridePtr);
217 HasIteration =
218 Builder.CreateICmp(llvm::CmpInst::Predicate::ICMP_EQ, HasWork,
219 Builder.getInt32(1), "polly.hasWork");
220 Builder.CreateCondBr(HasIteration, PreHeaderBB, ExitBB);
222 Builder.SetInsertPoint(PreHeaderBB);
223 LB = Builder.CreateAlignedLoad(LongType, LBPtr, Alignment,
224 "polly.indvar.LB");
225 UB = Builder.CreateAlignedLoad(LongType, UBPtr, Alignment,
226 "polly.indvar.UB");
228 break;
229 case OMPGeneralSchedulingType::StaticChunked:
230 case OMPGeneralSchedulingType::StaticNonChunked:
231 // "STATIC" scheduling types are handled below
233 Builder.CreateAlignedStore(AdjustedUB, UBPtr, Alignment);
234 createCallStaticInit(ID, IsLastPtr, LBPtr, UBPtr, StridePtr, ChunkSize);
236 Value *ChunkedStride = Builder.CreateAlignedLoad(
237 LongType, StridePtr, Alignment, "polly.kmpc.stride");
239 LB = Builder.CreateAlignedLoad(LongType, LBPtr, Alignment,
240 "polly.indvar.LB");
241 UB = Builder.CreateAlignedLoad(LongType, UBPtr, Alignment,
242 "polly.indvar.UB.temp");
244 Value *UBInRange =
245 Builder.CreateICmp(llvm::CmpInst::Predicate::ICMP_SLE, UB, AdjustedUB,
246 "polly.indvar.UB.inRange");
247 UB = Builder.CreateSelect(UBInRange, UB, AdjustedUB, "polly.indvar.UB");
248 Builder.CreateAlignedStore(UB, UBPtr, Alignment);
250 Value *HasIteration = Builder.CreateICmp(
251 llvm::CmpInst::Predicate::ICMP_SLE, LB, UB, "polly.hasIteration");
252 Builder.CreateCondBr(HasIteration, PreHeaderBB, ExitBB);
254 if (Scheduling == OMPGeneralSchedulingType::StaticChunked) {
255 Builder.SetInsertPoint(PreHeaderBB);
256 LB = Builder.CreateAlignedLoad(LongType, LBPtr, Alignment,
257 "polly.indvar.LB.entry");
258 UB = Builder.CreateAlignedLoad(LongType, UBPtr, Alignment,
259 "polly.indvar.UB.entry");
262 Builder.SetInsertPoint(CheckNextBB);
264 if (Scheduling == OMPGeneralSchedulingType::StaticChunked) {
265 Value *NextLB =
266 Builder.CreateAdd(LB, ChunkedStride, "polly.indvar.nextLB");
267 Value *NextUB = Builder.CreateAdd(UB, ChunkedStride);
269 Value *NextUBOutOfBounds =
270 Builder.CreateICmp(llvm::CmpInst::Predicate::ICMP_SGT, NextUB,
271 AdjustedUB, "polly.indvar.nextUB.outOfBounds");
272 NextUB = Builder.CreateSelect(NextUBOutOfBounds, AdjustedUB, NextUB,
273 "polly.indvar.nextUB");
275 Builder.CreateAlignedStore(NextLB, LBPtr, Alignment);
276 Builder.CreateAlignedStore(NextUB, UBPtr, Alignment);
278 Value *HasWork =
279 Builder.CreateICmp(llvm::CmpInst::Predicate::ICMP_SLE, NextLB,
280 AdjustedUB, "polly.hasWork");
281 Builder.CreateCondBr(HasWork, PreHeaderBB, ExitBB);
282 } else {
283 Builder.CreateBr(ExitBB);
286 Builder.SetInsertPoint(PreHeaderBB);
288 break;
291 Builder.CreateBr(CheckNextBB);
292 Builder.SetInsertPoint(&*--Builder.GetInsertPoint());
293 BasicBlock *AfterBB;
294 Value *IV = createLoop(LB, UB, SequentialLoopStride, Builder, LI, DT, AfterBB,
295 ICmpInst::ICMP_SLE, nullptr, true,
296 /* UseGuard */ false);
298 BasicBlock::iterator LoopBody = Builder.GetInsertPoint();
300 // Add code to terminate this subfunction.
301 Builder.SetInsertPoint(ExitBB);
302 // Static (i.e. non-dynamic) scheduling types, are terminated with a fini-call
303 if (Scheduling == OMPGeneralSchedulingType::StaticChunked ||
304 Scheduling == OMPGeneralSchedulingType::StaticNonChunked) {
305 createCallStaticFini(ID);
307 Builder.CreateRetVoid();
308 Builder.SetInsertPoint(&*LoopBody);
310 return std::make_tuple(IV, SubFn);
313 Value *ParallelLoopGeneratorKMP::createCallGlobalThreadNum() {
314 const std::string Name = "__kmpc_global_thread_num";
315 Function *F = M->getFunction(Name);
317 // If F is not available, declare it.
318 if (!F) {
319 StructType *IdentTy =
320 StructType::getTypeByName(M->getContext(), "struct.ident_t");
322 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
323 Type *Params[] = {IdentTy->getPointerTo()};
325 FunctionType *Ty = FunctionType::get(Builder.getInt32Ty(), Params, false);
326 F = Function::Create(Ty, Linkage, Name, M);
329 CallInst *Call = Builder.CreateCall(F, {SourceLocationInfo});
330 Call->setDebugLoc(DLGenerated);
331 return Call;
334 void ParallelLoopGeneratorKMP::createCallPushNumThreads(Value *GlobalThreadID,
335 Value *NumThreads) {
336 const std::string Name = "__kmpc_push_num_threads";
337 Function *F = M->getFunction(Name);
339 // If F is not available, declare it.
340 if (!F) {
341 StructType *IdentTy =
342 StructType::getTypeByName(M->getContext(), "struct.ident_t");
344 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
345 Type *Params[] = {IdentTy->getPointerTo(), Builder.getInt32Ty(),
346 Builder.getInt32Ty()};
348 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), Params, false);
349 F = Function::Create(Ty, Linkage, Name, M);
352 Value *Args[] = {SourceLocationInfo, GlobalThreadID, NumThreads};
354 CallInst *Call = Builder.CreateCall(F, Args);
355 Call->setDebugLoc(DLGenerated);
358 void ParallelLoopGeneratorKMP::createCallStaticInit(Value *GlobalThreadID,
359 Value *IsLastPtr,
360 Value *LBPtr, Value *UBPtr,
361 Value *StridePtr,
362 Value *ChunkSize) {
363 const std::string Name =
364 is64BitArch() ? "__kmpc_for_static_init_8" : "__kmpc_for_static_init_4";
365 Function *F = M->getFunction(Name);
366 StructType *IdentTy =
367 StructType::getTypeByName(M->getContext(), "struct.ident_t");
369 // If F is not available, declare it.
370 if (!F) {
371 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
373 Type *Params[] = {IdentTy->getPointerTo(),
374 Builder.getInt32Ty(),
375 Builder.getInt32Ty(),
376 Builder.getInt32Ty()->getPointerTo(),
377 LongType->getPointerTo(),
378 LongType->getPointerTo(),
379 LongType->getPointerTo(),
380 LongType,
381 LongType};
383 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), Params, false);
384 F = Function::Create(Ty, Linkage, Name, M);
387 // The parameter 'ChunkSize' will hold strictly positive integer values,
388 // regardless of PollyChunkSize's value
389 Value *Args[] = {
390 SourceLocationInfo,
391 GlobalThreadID,
392 Builder.getInt32(int(getSchedType(PollyChunkSize, PollyScheduling))),
393 IsLastPtr,
394 LBPtr,
395 UBPtr,
396 StridePtr,
397 ConstantInt::get(LongType, 1),
398 ChunkSize};
400 CallInst *Call = Builder.CreateCall(F, Args);
401 Call->setDebugLoc(DLGenerated);
404 void ParallelLoopGeneratorKMP::createCallStaticFini(Value *GlobalThreadID) {
405 const std::string Name = "__kmpc_for_static_fini";
406 Function *F = M->getFunction(Name);
407 StructType *IdentTy =
408 StructType::getTypeByName(M->getContext(), "struct.ident_t");
410 // If F is not available, declare it.
411 if (!F) {
412 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
413 Type *Params[] = {IdentTy->getPointerTo(), Builder.getInt32Ty()};
414 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), Params, false);
415 F = Function::Create(Ty, Linkage, Name, M);
418 Value *Args[] = {SourceLocationInfo, GlobalThreadID};
420 CallInst *Call = Builder.CreateCall(F, Args);
421 Call->setDebugLoc(DLGenerated);
424 void ParallelLoopGeneratorKMP::createCallDispatchInit(Value *GlobalThreadID,
425 Value *LB, Value *UB,
426 Value *Inc,
427 Value *ChunkSize) {
428 const std::string Name =
429 is64BitArch() ? "__kmpc_dispatch_init_8" : "__kmpc_dispatch_init_4";
430 Function *F = M->getFunction(Name);
431 StructType *IdentTy =
432 StructType::getTypeByName(M->getContext(), "struct.ident_t");
434 // If F is not available, declare it.
435 if (!F) {
436 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
438 Type *Params[] = {IdentTy->getPointerTo(),
439 Builder.getInt32Ty(),
440 Builder.getInt32Ty(),
441 LongType,
442 LongType,
443 LongType,
444 LongType};
446 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), Params, false);
447 F = Function::Create(Ty, Linkage, Name, M);
450 // The parameter 'ChunkSize' will hold strictly positive integer values,
451 // regardless of PollyChunkSize's value
452 Value *Args[] = {
453 SourceLocationInfo,
454 GlobalThreadID,
455 Builder.getInt32(int(getSchedType(PollyChunkSize, PollyScheduling))),
458 Inc,
459 ChunkSize};
461 CallInst *Call = Builder.CreateCall(F, Args);
462 Call->setDebugLoc(DLGenerated);
465 Value *ParallelLoopGeneratorKMP::createCallDispatchNext(Value *GlobalThreadID,
466 Value *IsLastPtr,
467 Value *LBPtr,
468 Value *UBPtr,
469 Value *StridePtr) {
470 const std::string Name =
471 is64BitArch() ? "__kmpc_dispatch_next_8" : "__kmpc_dispatch_next_4";
472 Function *F = M->getFunction(Name);
473 StructType *IdentTy =
474 StructType::getTypeByName(M->getContext(), "struct.ident_t");
476 // If F is not available, declare it.
477 if (!F) {
478 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
480 Type *Params[] = {IdentTy->getPointerTo(),
481 Builder.getInt32Ty(),
482 Builder.getInt32Ty()->getPointerTo(),
483 LongType->getPointerTo(),
484 LongType->getPointerTo(),
485 LongType->getPointerTo()};
487 FunctionType *Ty = FunctionType::get(Builder.getInt32Ty(), Params, false);
488 F = Function::Create(Ty, Linkage, Name, M);
491 Value *Args[] = {SourceLocationInfo, GlobalThreadID, IsLastPtr, LBPtr, UBPtr,
492 StridePtr};
494 CallInst *Call = Builder.CreateCall(F, Args);
495 Call->setDebugLoc(DLGenerated);
496 return Call;
499 // TODO: This function currently creates a source location dummy. It might be
500 // necessary to (actually) provide information, in the future.
501 GlobalVariable *ParallelLoopGeneratorKMP::createSourceLocation() {
502 const std::string LocName = ".loc.dummy";
503 GlobalVariable *SourceLocDummy = M->getGlobalVariable(LocName);
505 if (SourceLocDummy == nullptr) {
506 const std::string StructName = "struct.ident_t";
507 StructType *IdentTy =
508 StructType::getTypeByName(M->getContext(), StructName);
510 // If the ident_t StructType is not available, declare it.
511 // in LLVM-IR: ident_t = type { i32, i32, i32, i32, i8* }
512 if (!IdentTy) {
513 Type *LocMembers[] = {Builder.getInt32Ty(), Builder.getInt32Ty(),
514 Builder.getInt32Ty(), Builder.getInt32Ty(),
515 Builder.getPtrTy()};
517 IdentTy =
518 StructType::create(M->getContext(), LocMembers, StructName, false);
521 const auto ArrayType =
522 llvm::ArrayType::get(Builder.getInt8Ty(), /* Length */ 23);
524 // Global Variable Definitions
525 GlobalVariable *StrVar =
526 new GlobalVariable(*M, ArrayType, true, GlobalValue::PrivateLinkage,
527 nullptr, ".str.ident");
528 StrVar->setAlignment(llvm::Align(1));
530 SourceLocDummy = new GlobalVariable(
531 *M, IdentTy, true, GlobalValue::PrivateLinkage, nullptr, LocName);
532 SourceLocDummy->setAlignment(llvm::Align(8));
534 // Constant Definitions
535 Constant *InitStr = ConstantDataArray::getString(
536 M->getContext(), "Source location dummy.", true);
538 Constant *StrPtr = static_cast<Constant *>(Builder.CreateInBoundsGEP(
539 ArrayType, StrVar, {Builder.getInt32(0), Builder.getInt32(0)}));
541 Constant *LocInitStruct = ConstantStruct::get(
542 IdentTy, {Builder.getInt32(0), Builder.getInt32(0), Builder.getInt32(0),
543 Builder.getInt32(0), StrPtr});
545 // Initialize variables
546 StrVar->setInitializer(InitStr);
547 SourceLocDummy->setInitializer(LocInitStruct);
550 return SourceLocDummy;
553 bool ParallelLoopGeneratorKMP::is64BitArch() {
554 return (LongType->getIntegerBitWidth() == 64);
557 OMPGeneralSchedulingType ParallelLoopGeneratorKMP::getSchedType(
558 int ChunkSize, OMPGeneralSchedulingType Scheduling) const {
559 if (ChunkSize == 0 && Scheduling == OMPGeneralSchedulingType::StaticChunked)
560 return OMPGeneralSchedulingType::StaticNonChunked;
562 return Scheduling;