//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
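
// Note: INITIALIZE_PASS registers this pass with the legacy pass manager
// under the DEBUG_TYPE string above, so it can usually be exercised in
// isolation with something like:
//   opt -scalarize-masked-mem-intrin -S in.ll
// (shown as a convenience; the exact spelling is whatever the pass registry
// exposes).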

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}
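
// For example, a mask such as <4 x i1> <i1 1, i1 0, i1 1, i1 1> satisfies
// isConstantIntVector, whereas a mask containing undef or non-constant
// elements does not (getAggregateElement then yields something other than a
// ConstantInt).
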
// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;
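
  // If the mask is a compile-time constant vector, no control flow is
  // needed: emit a plain scalar load for each lane whose bit is known set
  // and skip the lanes known clear.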
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
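
  // Unlike the load expansion above, the guarded store chain below needs no
  // PHI nodes: each conditional block only performs a side effect, so no
  // value flows from one iteration's block to the next.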
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                               <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
    //  br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
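
// Translate a masked expandload intrinsic, e.g.
// <2 x i32> @llvm.masked.expandload.v2i32(i32* %ptr, <2 x i1> %mask,
//                                         <2 x i32> %passthru)
// into the same kind of block chain: each lane conditionally loads from the
// current pointer and, when its mask bit is set, the pointer advances by one
// element, so enabled lanes read consecutive memory. An illustrative sketch
// (not the verbatim output) for the first lane:
//
//  %mask_0 = extractelement <2 x i1> %mask, i32 0
//  br i1 %mask_0, label %cond.load, label %else
//
// cond.load:
//  %elt0 = load i32, i32* %ptr, align 1
//  %res0 = insertelement <2 x i32> %passthru, i32 %elt0, i32 0
//  %ptr0 = getelementptr inbounds i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %res.phi.else = phi <2 x i32> [ %res0, %cond.load ], [ %passthru, %0 ]
//  %ptr.phi.else = phi i32* [ %ptr0, %cond.load ], [ %ptr, %0 ]
//  . . .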
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;
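
  // Note there is no alignment operand on expandload; each scalar load
  // below is emitted with alignment 1, and the pointer only advances past
  // lanes whose mask bit was set.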
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
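
// Translate a masked compressstore intrinsic, e.g.
// void @llvm.masked.compressstore.v2i32(<2 x i32> %src, i32* %ptr,
//                                       <2 x i1> %mask)
// into the same kind of block chain: each set mask bit stores its element
// and bumps the pointer by one element, so the enabled elements end up
// packed contiguously at %ptr. An illustrative sketch (not the verbatim
// output) for the first lane:
//
//  %mask_0 = extractelement <2 x i1> %mask, i32 0
//  br i1 %mask_0, label %cond.store, label %else
//
// cond.store:
//  %elt0 = extractelement <2 x i32> %src, i32 0
//  store i32 %elt0, i32* %ptr, align 1
//  %ptr0 = getelementptr inbounds i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %ptr.phi.else = phi i32* [ %ptr0, %cond.store ], [ %ptr, %0 ]
//  . . .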
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
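
// Pass driver: visit every block in the function, scalarizing any masked
// memory intrinsics the target cannot legalize, and restart block iteration
// whenever an expansion has changed the CFG (and hence the dominator tree).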
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    // Advance the iterator before visiting CI: optimizeCallInst may erase CI
    // from the block.
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // For each masked intrinsic, first ask TTI whether the target can lower
    // it natively; only scalarize when it cannot.
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI->isLegalMaskedLoad(CI->getType()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather:
      if (TTI->isLegalMaskedGather(CI->getType()))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}