//===- TernlogTest.cpp - X86 ternlog unit tests ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"

#include "gtest/gtest.h"
#include <random>

using namespace llvm;
23 static std::unique_ptr
<TargetMachine
> initTM() {
24 LLVMInitializeX86TargetInfo();
25 LLVMInitializeX86Target();
26 LLVMInitializeX86TargetMC();
28 auto TT(Triple::normalize("x86_64--"));
30 const Target
*TheTarget
= TargetRegistry::lookupTarget(TT
, Error
);
31 return std::unique_ptr
<TargetMachine
>(
32 TheTarget
->createTargetMachine(TT
, "", "", TargetOptions(), std::nullopt
,
33 std::nullopt
, CodeGenOptLevel::Default
));
41 SmallVector
<uint64_t, 16> VecElems
[3];
43 void updateImm(uint8_t NewImmVal
) { ImmVal
= NewImmVal
; }
44 void updateNElem(unsigned NewNElem
) {
46 for (unsigned I
= 0; I
< 3; ++I
) {
47 VecElems
[I
].resize(NElem
);
50 void updateElemWidth(unsigned NewElemWidth
) {
51 ElemWidth
= NewElemWidth
;
52 assert(ElemWidth
== 32 || ElemWidth
== 64);
55 uint64_t getElemMask() const {
56 return (~uint64_t(0)) >> ((ElemWidth
- 0) % 64);
59 void RandomizeVecArgs() {
60 uint64_t ElemMask
= getElemMask();
61 for (unsigned I
= 0; I
< 3; ++I
) {
62 for (unsigned J
= 0; J
< NElem
; ++J
) {
63 VecElems
[I
][J
] = Rng() & ElemMask
;
68 std::pair
<std::string
, std::string
> getScalarInfo() const {
75 llvm_unreachable("Invalid ElemWidth");
78 std::string
getScalarType() const { return getScalarInfo().first
; }
79 std::string
getScalarExt() const { return getScalarInfo().second
; }
80 std::string
getVecType() const {
81 return "<" + Twine(NElem
).str() + " x " + getScalarType() + ">";
84 std::string
getVecWidth() const { return Twine(NElem
* ElemWidth
).str(); }
85 std::string
getFunctionName() const {
86 return "@llvm.x86.avx512.pternlog." + getScalarExt() + "." + getVecWidth();
88 std::string
getFunctionDecl() const {
89 return "declare " + getVecType() + getFunctionName() + "(" + getVecType() +
90 ", " + getVecType() + ", " + getVecType() + ", " + "i32 immarg)";
93 std::string
getVecN(unsigned N
) const {
95 std::string VecStr
= getVecType() + " <";
96 for (unsigned I
= 0; I
< VecElems
[N
].size(); ++I
) {
99 VecStr
+= getScalarType() + " " + Twine(VecElems
[N
][I
]).str();
103 std::string
getFunctionCall() const {
104 return "tail call " + getVecType() + " " + getFunctionName() + "(" +
105 getVecN(0) + ", " + getVecN(1) + ", " + getVecN(2) + ", " + "i32 " +
106 Twine(ImmVal
).str() + ")";
109 std::string
getTestText() const {
110 return getFunctionDecl() + "\ndefine " + getVecType() +
111 "@foo() {\n%r = " + getFunctionCall() + "\nret " + getVecType() +
115 void checkResult(const Value
*V
) {
116 auto GetValElem
= [&](unsigned Idx
) -> uint64_t {
117 if (auto *CV
= dyn_cast
<ConstantDataVector
>(V
))
118 return CV
->getElementAsInteger(Idx
);
120 auto *C
= dyn_cast
<Constant
>(V
);
122 if (C
->isNullValue())
124 if (C
->isAllOnesValue())
125 return ((~uint64_t(0)) >> (ElemWidth
% 64));
129 llvm_unreachable("Unknown constant type");
132 auto ComputeBit
= [&](uint64_t A
, uint64_t B
, uint64_t C
) -> uint64_t {
133 unsigned BitIdx
= ((A
& 1) << 2) | ((B
& 1) << 1) | (C
& 1);
134 return (ImmVal
>> BitIdx
) & 1;
137 for (unsigned I
= 0; I
< NElem
; ++I
) {
139 uint64_t AEle
= VecElems
[0][I
];
140 uint64_t BEle
= VecElems
[1][I
];
141 uint64_t CEle
= VecElems
[2][I
];
142 for (unsigned J
= 0; J
< ElemWidth
; ++J
) {
143 Expec
|= ComputeBit(AEle
>> J
, BEle
>> J
, CEle
>> J
) << J
;
146 ASSERT_EQ(Expec
, GetValElem(I
));
150 void check(LLVMContext
&Ctx
, FunctionPassManager
&FPM
,
151 FunctionAnalysisManager
&FAM
) {
153 std::unique_ptr
<Module
> M
= parseAssemblyString(getTestText(), Error
, Ctx
);
155 Function
*F
= M
->getFunction("foo");
157 ASSERT_EQ(F
->getInstructionCount(), 2u);
160 ASSERT_EQ(F
->getInstructionCount(), 1u);
161 ASSERT_EQ(F
->size(), 1u);
162 const Instruction
*I
= F
->begin()->getTerminator();
164 ASSERT_EQ(I
->getNumOperands(), 1u);
165 checkResult(I
->getOperand(0));
169 TEST(TernlogTest
, TestConstantFolding
) {
171 FunctionAnalysisManager FAM
;
172 FunctionPassManager FPM
;
174 LoopAnalysisManager LAM
;
175 CGSCCAnalysisManager CGAM
;
176 ModuleAnalysisManager MAM
;
177 TargetIRAnalysis TIRA
= TargetIRAnalysis(
178 [&](const Function
&F
) { return initTM()->getTargetTransformInfo(F
); });
180 FAM
.registerPass([&] { return TIRA
; });
181 PB
.registerModuleAnalyses(MAM
);
182 PB
.registerCGSCCAnalyses(CGAM
);
183 PB
.registerFunctionAnalyses(FAM
);
184 PB
.registerLoopAnalyses(LAM
);
185 PB
.crossRegisterProxies(LAM
, FAM
, CGAM
, MAM
);
187 FPM
.addPass(InstCombinePass());
189 for (unsigned NElem
= 2; NElem
< 16; NElem
+= NElem
) {
190 TT
.updateNElem(NElem
);
191 for (unsigned ElemWidth
= 32; ElemWidth
<= 64; ElemWidth
+= ElemWidth
) {
192 if (ElemWidth
* NElem
> 512 || ElemWidth
* NElem
< 128)
194 TT
.updateElemWidth(ElemWidth
);
195 TT
.RandomizeVecArgs();
196 for (unsigned Imm
= 0; Imm
< 256; ++Imm
) {
198 TT
.check(Ctx
, FPM
, FAM
);