[x86] fix assert with horizontal math + broadcast of vector (PR43402)
[llvm-core.git] / lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one by one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
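
// A note on the expansion strategy used by the helpers below: when the mask is
// a compile-time constant vector, straight-line code is emitted for the
// enabled lanes only; when the mask is variable, one conditional block is
// emitted per lane, guarded either by an extracted mask element (for v1i1
// masks) or by a scalar bit-test of the mask bitcast to an integer.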

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
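
// Illustrative usage (assumed, not part of this file): the pass is normally
// added to the codegen pipeline through this factory, e.g.
//
//   addPass(createScalarizeMaskedMemIntrinPass()); // e.g. in TargetPassConfig::addIRPasses()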

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                            ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
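// When the mask is a constant vector no control flow is emitted; the enabled
// lanes are simply loaded and inserted in order, roughly (illustrative sketch
// for a <4 x i32> load with mask <1, 0, 1, 0>):
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = getelementptr inbounds i32, i32* %1, i32 0
//  %3 = load i32, i32* %2
//  %4 = insertelement <4 x i32> %passthru, i32 %3, i32 0
//  %5 = getelementptr inbounds i32, i32* %1, i32 2
//  %6 = load i32, i32* %5
//  %7 = insertelement <4 x i32> %4, i32 %6, i32 2
//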
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
// . . .
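//
// With a constant mask the expansion is branch-free: each enabled lane is
// extracted and stored directly, roughly (illustrative sketch for lane 0):
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i32> %val, i32 0
//  %3 = getelementptr inbounds i32, i32* %1, i32 0
//  store i32 %2, i32* %3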
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
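//
// In the code below %Src is used as the initial result vector, so lanes whose
// mask bit is clear simply keep their %Src element; the final select shown
// above is conceptual and is not emitted. A constant mask again produces
// branch-free extract/load/insert code for the enabled lanes only.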
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
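//
// As with the other expansions, a constant mask produces branch-free
// extract/store code, and a variable non-v1i1 mask is bitcast to an iN
// integer so each lane's guard is a scalar bit-test rather than an
// extractelement of the mask.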
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
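
// Translate a masked expandload intrinsic like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// into a chain of basic blocks in the same style as the masked load expansion
// above, with one difference: the loads read consecutive memory locations, so
// the pointer is advanced only past lanes whose mask bit was set.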
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
      ++MemIndex;
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
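
// Translate a masked compressstore intrinsic like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// into a chain of basic blocks in the same style as the masked store expansion
// above, except that enabled lanes are stored to consecutive memory locations,
// advancing the pointer once per stored element.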
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI->isLegalMaskedLoad(CI->getType()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather:
      if (TTI->isLegalMaskedGather(CI->getType()))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}