//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics, when they are not supported by
// the target, with a chain of basic blocks that handle the elements one by one
// if the corresponding mask bit is set.
//
//===----------------------------------------------------------------------===//

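// For example (an illustrative sketch, not IR copied verbatim from this
// file), a call such as
//
//   %v = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %p, i32 4,
//                                                       <4 x i1> %m,
//                                                       <4 x i32> %passthru)
//
// becomes a chain of blocks that test each bit of %m and perform the
// corresponding scalar load only when that bit is set.
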
#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

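// (Usage sketch, not code from this file: a legacy-pass-manager client would
// schedule the pass above via
//   legacy::PassManager PM;
//   PM.add(createScalarizeMaskedMemIntrinLegacyPass());
// the new-pass-manager entry point is ScalarizeMaskedMemIntrinPass::run
// further below.)
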
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

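// Map a mask-element index to the bit that must be tested once the <N x i1>
// mask has been bitcast to an iN scalar. On big-endian targets element 0 of,
// say, a 4-element mask ends up in bit 3 of the i4 value (illustrative
// numbers); on little-endian targets the index is used as-is.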
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
//   <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that load the elements one by one if
// the appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                         ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
//   void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                           <16 x i1> %mask)
// to a chain of basic blocks that store the elements one by one if
// the appropriate mask bit is set.
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
//   <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that load the elements one by one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
//   void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                    <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one by one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

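// Translate a masked expandload intrinsic, like
//   <4 x i32> @llvm.masked.expandload.v4i32(i32* %ptr, <4 x i1> %mask,
//                                           <4 x i32> %passthru)
// so that each enabled mask lane receives the next consecutive element loaded
// from %ptr, while disabled lanes keep the %passthru value. A constant mask is
// lowered to a build_vector of loads plus a shuffle blend; any other mask
// produces a conditional-block chain like the lowerings above. (The signature
// shown is an illustrative sketch, not copied from this file.)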
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

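// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore.v4i32(<4 x i32> %value, i32* %ptr,
//                                         <4 x i1> %mask)
// so that only the enabled lanes of %value are stored, packed into consecutive
// locations starting at %ptr. As above, a constant mask produces a straight
// sequence of stores and any other mask produces a conditional-block chain.
// (The signature shown is an illustrative sketch, not copied from this file.)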
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  Optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU.hasValue() ? DTU.getPointer() : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

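// (Usage note, an assumption based on the standard pass registration rather
// than anything in this file: the new-PM pass is normally exercised as
//   opt -S -passes=scalarize-masked-mem-intrin input.ll
// which matches the DEBUG_TYPE string above.)
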
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}