//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics that the target does not support
// with a chain of basic blocks that handle the elements one by one if the
// corresponding mask bit is set.
//
//===----------------------------------------------------------------------===//
16 #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Analysis/DomTreeUpdater.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/VectorUtils.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/IR/Instruction.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/IR/Value.h"
33 #include "llvm/InitializePasses.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Transforms/Scalar.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}
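
// Map a vector lane index to the corresponding bit of the mask once the mask
// has been bitcast to an iN scalar. On big-endian targets the lane order is
// reversed relative to the bit order, so lane Idx lives at bit
// VectorWidth - 1 - Idx; for example, lane 0 of a <4 x i1> mask bitcast to i4
// is bit 3 on a big-endian target and bit 0 on a little-endian one.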
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                            ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
                                CallInst *CI, DomTreeUpdater *DTU,
                                bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    NewI->copyMetadata(*CI);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // Optimize the case where the "masked load" is a predicated load - that is,
  // where the mask is the splat of a non-constant scalar boolean. In that case,
  // use that splatted value as the guard on a conditional vector load.
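  //
  // For example, a load whose mask is a splat of a scalar condition is lowered
  // to roughly the following (value names are illustrative):
  //
  //  %mask.first = extractelement <16 x i1> %mask, i64 0
  //  br i1 %mask.first, label %cond.load, label %post
  // cond.load:
  //  %wide.load = load <16 x i32>, ptr %addr
  //  br label %post
  // post:
  //  %res = phi <16 x i32> [ %wide.load, %cond.load ], [ %passthru, %entry ]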
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
                                                    Mask->getName() + ".first");
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");
    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
                                               CI->getName() + ".cond.load");
    Load->copyMetadata(*CI);

    BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
    Builder.SetInsertPoint(PostLoad, PostLoad->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
    Phi->addIncoming(Load, CondBlock);
    Phi->addIncoming(Src0, IfBlock);

    CI->replaceAllUsesWith(Phi);
    CI->eraseFromParent();
    ModifiedDT = true;
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs and other
  // machines with divergence, as there each i1 needs a vector register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    // On GPUs and other machines with branch divergence, use instead
    //  %cond = extractelement %mask, Idx
    //
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  . . .
static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
                                 CallInst *CI, DomTreeUpdater *DTU,
                                 bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->copyMetadata(*CI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // Optimize the case where the "masked store" is a predicated store - that is,
  // when the mask is the splat of a non-constant scalar boolean. In that case,
  // optimize to a conditional store.
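  //
  // For example, a store whose mask is a splat of a scalar condition becomes
  // roughly the following (value names are illustrative):
  //
  //  %mask.first = extractelement <16 x i1> %mask, i64 0
  //  br i1 %mask.first, label %cond.store, label %post
  // cond.store:
  //  store <16 x i32> %src, ptr %addr
  //  br label %post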
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
                                                    Mask->getName() + ".first");
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);
    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");
    Builder.SetInsertPoint(CondBlock->getTerminator());

    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->copyMetadata(*CI);

    CI->eraseFromParent();
    ModifiedDT = true;
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    // On GPUs and other machines with branch divergence, use instead
    //  %cond = extractelement %mask, Idx
    //
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL,
                                  bool HasBranchDivergence, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //
    // On GPUs and other machines with branch divergence, use instead
    //  %cond = extractelement %mask, Idx
    //
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// . . .
static void scalarizeMaskedScatter(const DataLayout &DL,
                                   bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    // On GPUs and other machines with branch divergence, use instead
    //  %cond = extractelement %mask, Idx
    //
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
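
// Translate a masked expandload intrinsic, like
//   <16 x i32> @llvm.masked.expandload.v16i32(ptr %ptr, <16 x i1> %mask,
//                                             <16 x i32> %passthru)
// into a sequence of scalar loads from consecutive memory locations: the load
// pointer advances only for lanes whose mask bit is set, and the loaded values
// are blended with %passthru. When the mask is not a compile-time constant the
// result is a chain of conditional blocks in which phi nodes carry both the
// partial result vector and the current pointer.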
static void scalarizeMaskedExpandLoad(const DataLayout &DL,
                                      bool HasBranchDivergence, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(0).valueOrOne();

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/poisons as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = PoisonValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    // On GPUs and other machines with branch divergence, use instead
    //  %cond = extractelement %mask, Idx
    //
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
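
// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore.v16i32(<16 x i32> %src, ptr %ptr,
//                                          <16 x i1> %mask)
// into a sequence of scalar stores of the enabled elements to consecutive
// memory locations, advancing the store pointer only for lanes whose mask bit
// is set. When the mask is not a compile-time constant this becomes a chain of
// conditional blocks with phi nodes carrying the current store pointer.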
static void scalarizeMaskedCompressStore(const DataLayout &DL,
                                         bool HasBranchDivergence, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(1).valueOrOne();

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    // On GPUs and other machines with branch divergence, use instead
    //  %cond = extractelement %mask, Idx
    //
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
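
// Translate a masked histogram intrinsic, like
//   void @llvm.experimental.vector.histogram.add(<16 x ptr> %ptrs, i32 %inc,
//                                                <16 x i1> %mask)
// into a per-lane load/add/store sequence: for every enabled lane the value at
// that lane's pointer is loaded, incremented by %inc, and stored back. For a
// non-constant mask each lane gets its own conditional block.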
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
                                           DomTreeUpdater *DTU,
                                           bool &ModifiedDT) {
  // If we extend histogram to return a result someday (like the updated vector)
  // then we'll need to support it here.
  assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
  Value *Ptrs = CI->getArgOperand(0);
  Value *Inc = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
  Type *EltTy = Inc->getType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // FIXME: Do we need to add an alignment parameter to the intrinsic?
  unsigned VectorWidth = AddrType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
      Value *Add = Builder.CreateAdd(Load, Inc);
      Builder.CreateStore(Add, Ptr);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.histogram.update");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
    Value *Add = Builder.CreateAdd(Load, Inc);
    Builder.CreateStore(Add, Ptr);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }

  CI->eraseFromParent();
  ModifiedDT = true;
}
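
// Shared driver for both pass managers: repeatedly scan the function and
// scalarize any masked memory intrinsics the target cannot handle, restarting
// the basic-block walk whenever a transformation changes the dominator tree.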
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getDataLayout();
  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  HasBranchDivergence, DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}
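
// Legacy pass manager entry point: fetch TTI, reuse the DominatorTree if it is
// already available, and defer to the shared implementation.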
bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}
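
// Scan a single basic block and scalarize any masked memory intrinsic calls in
// it. Returns true if a change was made; bails out as soon as a transformation
// has modified the dominator tree so the caller can restart its iteration.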
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |=
          optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}
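
// Dispatch on the intrinsic being called: if the target reports the masked
// operation as legal for the given type and alignment, the call is left alone;
// otherwise it is scalarized by the matching helper above. Scalable vectors
// are never scalarized here.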
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::experimental_vector_histogram_add:
      if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
                                           CI->getArgOperand(1)->getType()))
        return false;
      scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(
              CI->getType(),
              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(
              CI->getArgOperand(0)->getType(),
              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
                                   ModifiedDT);
      return true;
    }
  }

  return false;
}