//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
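
// Return true if Mask is a constant vector whose elements are all
// ConstantInts, i.e. each lane's predicate is known at compile time.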
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;
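
  // If the mask is known at compile time, emit loads only for the enabled
  // lanes; the disabled lanes keep their pass-through elements and no control
  // flow is needed.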
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
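    // splitBasicBlock terminates IfBlock with an unconditional branch; replace
    // it with a conditional branch on this element's mask bit.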
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//  . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();
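
  // If the mask is known at compile time, emit stores for the enabled lanes
  // directly; no control flow is needed.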
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//  . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
    //  br i1 %Mask1, label %cond.load, label %else
    //
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

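    // Create the phi to join the new and previous value.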
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one-by-one if
// the appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//  . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
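
// Translate a masked expandload intrinsic, e.g.
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// into a chain of basic blocks like the masked-load lowering above. The
// difference is that the enabled elements are read from *consecutive* memory
// locations: the pointer is advanced by one element after each enabled lane,
// so an extra phi carries the updated pointer between iterations.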
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);
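    // The expandload intrinsic carries no alignment operand, so a conservative
    // alignment of 1 is used for the scalar load.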
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
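
// Translate a masked compressstore intrinsic, e.g.
// void @llvm.masked.compressstore.v16i32(<16 x i32> %Src, i32* %ptr,
//                                        <16 x i1> %mask)
// into the same kind of block chain as the masked-store lowering above,
// except that the enabled elements are written to *consecutive* memory
// locations; a phi advances the pointer past each stored element.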
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
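    // compressstore likewise has no alignment operand, so the scalar store
    // conservatively uses alignment 1.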
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
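    // For each masked intrinsic, first ask TTI whether the target supports it
    // natively; only scalarize when it does not.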
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI->isLegalMaskedLoad(CI->getType()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather:
      if (TTI->isLegalMaskedGather(CI->getType()))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}