//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {
class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}
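
// Returns true if the mask is a fixed-width vector whose elements are all
// ConstantInt, i.e. every lane's on/off state is known at compile time.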
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}
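
// When the i1 mask vector is bitcast to an integer, the bit for lane Idx ends
// up at a different position on big-endian targets; remap the lane index so
// the scalar bit tests extract the right mask bit.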
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}
// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that load the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                                        <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that load the elements one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
//   . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
//   . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
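
// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// by loading from consecutive memory locations at %ptr for each enabled mask
// lane, and taking the corresponding %passthru element for disabled lanes.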
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
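
// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// by storing the elements of %value whose mask bit is set to consecutive
// memory locations starting at %ptr.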
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
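
// Shared driver for the legacy and new pass managers: repeatedly walk the
// blocks of F, scalarizing unsupported masked memory intrinsics, and keep the
// (optional) dominator tree updated through a lazy DomTreeUpdater.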
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  Optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU.hasValue() ? DTU.getPointer() : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}
bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}
PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}
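
// If CI is a masked load/store, gather/scatter, expandload or compressstore
// that the target cannot handle natively (per TTI), scalarize it and return
// true; otherwise leave the call alone and return false.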
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);