//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks, that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"
namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace
char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
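
// Whether the mask is a compile-time vector of ConstantInts. When it is, the
// scalarizers below can resolve every lane statically and avoid emitting any
// branches for the disabled lanes.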
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
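  // Illustrative sketch (not copied from emitted IR): for a <16 x i1> mask
  // this produces
  //   %scalar_mask = bitcast <16 x i1> %mask to i16
  //   %bit         = and i16 %scalar_mask, (1 << Idx)
  //   %cond        = icmp ne i16 %bit, 0
  // so each lane's predicate becomes a cheap integer bit test instead of an
  // extractelement from the i1 vector.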
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }
201 // Create "cond" block
203 // %EltAddr = getelementptr i32* %1, i32 0
204 // %Elt = load i32* %EltAddr
205 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
207 BasicBlock
*CondBlock
= IfBlock
->splitBasicBlock(InsertPt
->getIterator(),
209 Builder
.SetInsertPoint(InsertPt
);
211 Value
*Gep
= Builder
.CreateConstInBoundsGEP1_32(EltTy
, FirstEltPtr
, Idx
);
212 LoadInst
*Load
= Builder
.CreateAlignedLoad(EltTy
, Gep
, AlignVal
);
213 Value
*NewVResult
= Builder
.CreateInsertElement(VResult
, Load
, Idx
);
215 // Create "else" block, fill it in the next iteration
216 BasicBlock
*NewIfBlock
=
217 CondBlock
->splitBasicBlock(InsertPt
->getIterator(), "else");
218 Builder
.SetInsertPoint(InsertPt
);
219 Instruction
*OldBr
= IfBlock
->getTerminator();
220 BranchInst::Create(CondBlock
, NewIfBlock
, Predicate
, OldBr
);
221 OldBr
->eraseFromParent();
222 BasicBlock
*PrevIfBlock
= IfBlock
;
223 IfBlock
= NewIfBlock
;
225 // Create the phi to join the new and previous value.
226 PHINode
*Phi
= Builder
.CreatePHI(VecType
, 2, "res.phi.else");
227 Phi
->addIncoming(NewVResult
, CondBlock
);
228 Phi
->addIncoming(VResult
, PrevIfBlock
);
232 CI
->replaceAllUsesWith(VResult
);
233 CI
->eraseFromParent();

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                        ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//  . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
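
// Unlike the load expansion above, the store expansion needs no "res.phi.else"
// phi nodes: a scalarized store produces no value that later iterations have
// to merge, so each cond.store block simply stores its lane and falls through
// to the next else block.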

// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                               <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
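
// Translate a masked expandload intrinsic. The shape of the expansion mirrors
// the masked load above, except that enabled lanes read from consecutive
// memory locations, so the pointer only advances when a lane is executed.
// Roughly (illustrative sketch, not copied from emitted IR):
//
//  %Load0 = load i32, i32* %ptr, align 1
//  %Res0  = insertelement <16 x i32> %passthru, i32 %Load0, i32 0
//  %ptr1  = getelementptr i32, i32* %ptr, i32 1   ; only if lane 0 executed
//  . . .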
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
      ++MemIndex;
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
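
// Translate a masked compressstore intrinsic. This is the store-side dual of
// the expandload above: enabled lanes are written to consecutive memory
// locations, so the store pointer only advances after a lane is actually
// stored. Roughly (illustrative sketch, not copied from emitted IR):
//
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  store i32 %Elt0, i32* %ptr, align 1
//  %ptr1 = getelementptr i32, i32* %ptr, i32 1   ; only if lane 0 executed
//  . . .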
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
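
// Driver: walk every basic block of the function, scalarizing any masked
// memory intrinsic the target reports as illegal. Iteration restarts whenever
// an expansion introduces new basic blocks, so the newly created blocks are
// revisited as well.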
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}
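
// Dispatch on the masked intrinsic kind: each case first asks TTI whether the
// target can lower the intrinsic natively (in which case nothing is done) and
// otherwise calls the matching scalarization routine above.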
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load: {
      // Scalarize unsupported vector masked load
      unsigned Alignment =
          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedLoad(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_store: {
      unsigned Alignment =
          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType(),
                                  MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_gather:
      if (TTI->isLegalMaskedGather(CI->getType()))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}