[MLIR] Prevent invalid IR from being passed outside of RemoveDeadValues (#121079)
[llvm-project.git] / llvm / unittests / Target / X86 / TernlogTest.cpp
blobedb4431a05a6800d643e8ec77da080cb95377265
1 //===- TernlogTest.cpp - Ternlog unit tests -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Analysis/TargetTransformInfo.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/MC/TargetRegistry.h"
13 #include "llvm/Passes/PassBuilder.h"
14 #include "llvm/Support/TargetSelect.h"
15 #include "llvm/Target/TargetMachine.h"
16 #include "llvm/Transforms/InstCombine/InstCombine.h"
18 #include "gtest/gtest.h"
20 #include <random>
22 namespace llvm {
23 static std::unique_ptr<TargetMachine> initTM() {
24 LLVMInitializeX86TargetInfo();
25 LLVMInitializeX86Target();
26 LLVMInitializeX86TargetMC();
28 auto TT(Triple::normalize("x86_64--"));
29 std::string Error;
30 const Target *TheTarget = TargetRegistry::lookupTarget(TT, Error);
31 return std::unique_ptr<TargetMachine>(
32 TheTarget->createTargetMachine(TT, "", "", TargetOptions(), std::nullopt,
33 std::nullopt, CodeGenOptLevel::Default));
36 struct TernTester {
37 unsigned NElem;
38 unsigned ElemWidth;
39 std::mt19937_64 Rng;
40 unsigned ImmVal;
41 SmallVector<uint64_t, 16> VecElems[3];
43 void updateImm(uint8_t NewImmVal) { ImmVal = NewImmVal; }
44 void updateNElem(unsigned NewNElem) {
45 NElem = NewNElem;
46 for (unsigned I = 0; I < 3; ++I) {
47 VecElems[I].resize(NElem);
50 void updateElemWidth(unsigned NewElemWidth) {
51 ElemWidth = NewElemWidth;
52 assert(ElemWidth == 32 || ElemWidth == 64);
55 uint64_t getElemMask() const {
56 return (~uint64_t(0)) >> ((ElemWidth - 0) % 64);
59 void RandomizeVecArgs() {
60 uint64_t ElemMask = getElemMask();
61 for (unsigned I = 0; I < 3; ++I) {
62 for (unsigned J = 0; J < NElem; ++J) {
63 VecElems[I][J] = Rng() & ElemMask;
68 std::pair<std::string, std::string> getScalarInfo() const {
69 switch (ElemWidth) {
70 case 32:
71 return {"i32", "d"};
72 case 64:
73 return {"i64", "q"};
74 default:
75 llvm_unreachable("Invalid ElemWidth");
78 std::string getScalarType() const { return getScalarInfo().first; }
79 std::string getScalarExt() const { return getScalarInfo().second; }
80 std::string getVecType() const {
81 return "<" + Twine(NElem).str() + " x " + getScalarType() + ">";
84 std::string getVecWidth() const { return Twine(NElem * ElemWidth).str(); }
85 std::string getFunctionName() const {
86 return "@llvm.x86.avx512.pternlog." + getScalarExt() + "." + getVecWidth();
88 std::string getFunctionDecl() const {
89 return "declare " + getVecType() + getFunctionName() + "(" + getVecType() +
90 ", " + getVecType() + ", " + getVecType() + ", " + "i32 immarg)";
93 std::string getVecN(unsigned N) const {
94 assert(N < 3);
95 std::string VecStr = getVecType() + " <";
96 for (unsigned I = 0; I < VecElems[N].size(); ++I) {
97 if (I != 0)
98 VecStr += ", ";
99 VecStr += getScalarType() + " " + Twine(VecElems[N][I]).str();
101 return VecStr + ">";
103 std::string getFunctionCall() const {
104 return "tail call " + getVecType() + " " + getFunctionName() + "(" +
105 getVecN(0) + ", " + getVecN(1) + ", " + getVecN(2) + ", " + "i32 " +
106 Twine(ImmVal).str() + ")";
109 std::string getTestText() const {
110 return getFunctionDecl() + "\ndefine " + getVecType() +
111 "@foo() {\n%r = " + getFunctionCall() + "\nret " + getVecType() +
112 " %r\n}\n";
115 void checkResult(const Value *V) {
116 auto GetValElem = [&](unsigned Idx) -> uint64_t {
117 if (auto *CV = dyn_cast<ConstantDataVector>(V))
118 return CV->getElementAsInteger(Idx);
120 auto *C = dyn_cast<Constant>(V);
121 assert(C);
122 if (C->isNullValue())
123 return 0;
124 if (C->isAllOnesValue())
125 return ((~uint64_t(0)) >> (ElemWidth % 64));
126 if (C->isOneValue())
127 return 1;
129 llvm_unreachable("Unknown constant type");
132 auto ComputeBit = [&](uint64_t A, uint64_t B, uint64_t C) -> uint64_t {
133 unsigned BitIdx = ((A & 1) << 2) | ((B & 1) << 1) | (C & 1);
134 return (ImmVal >> BitIdx) & 1;
137 for (unsigned I = 0; I < NElem; ++I) {
138 uint64_t Expec = 0;
139 uint64_t AEle = VecElems[0][I];
140 uint64_t BEle = VecElems[1][I];
141 uint64_t CEle = VecElems[2][I];
142 for (unsigned J = 0; J < ElemWidth; ++J) {
143 Expec |= ComputeBit(AEle >> J, BEle >> J, CEle >> J) << J;
146 ASSERT_EQ(Expec, GetValElem(I));
150 void check(LLVMContext &Ctx, FunctionPassManager &FPM,
151 FunctionAnalysisManager &FAM) {
152 SMDiagnostic Error;
153 std::unique_ptr<Module> M = parseAssemblyString(getTestText(), Error, Ctx);
154 ASSERT_TRUE(M);
155 Function *F = M->getFunction("foo");
156 ASSERT_TRUE(F);
157 ASSERT_EQ(F->getInstructionCount(), 2u);
158 FAM.clear();
159 FPM.run(*F, FAM);
160 ASSERT_EQ(F->getInstructionCount(), 1u);
161 ASSERT_EQ(F->size(), 1u);
162 const Instruction *I = F->begin()->getTerminator();
163 ASSERT_TRUE(I);
164 ASSERT_EQ(I->getNumOperands(), 1u);
165 checkResult(I->getOperand(0));
169 TEST(TernlogTest, TestConstantFolding) {
170 LLVMContext Ctx;
171 FunctionAnalysisManager FAM;
172 FunctionPassManager FPM;
173 PassBuilder PB;
174 LoopAnalysisManager LAM;
175 CGSCCAnalysisManager CGAM;
176 ModuleAnalysisManager MAM;
177 TargetIRAnalysis TIRA = TargetIRAnalysis(
178 [&](const Function &F) { return initTM()->getTargetTransformInfo(F); });
180 FAM.registerPass([&] { return TIRA; });
181 PB.registerModuleAnalyses(MAM);
182 PB.registerCGSCCAnalyses(CGAM);
183 PB.registerFunctionAnalyses(FAM);
184 PB.registerLoopAnalyses(LAM);
185 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
187 FPM.addPass(InstCombinePass());
188 TernTester TT;
189 for (unsigned NElem = 2; NElem < 16; NElem += NElem) {
190 TT.updateNElem(NElem);
191 for (unsigned ElemWidth = 32; ElemWidth <= 64; ElemWidth += ElemWidth) {
192 if (ElemWidth * NElem > 512 || ElemWidth * NElem < 128)
193 continue;
194 TT.updateElemWidth(ElemWidth);
195 TT.RandomizeVecArgs();
196 for (unsigned Imm = 0; Imm < 256; ++Imm) {
197 TT.updateImm(Imm);
198 TT.check(Ctx, FPM, FAM);
203 } // namespace llvm