[MLIR] Prevent invalid IR from being passed outside of RemoveDeadValues (#121079)
[llvm-project.git] / llvm / lib / Target / DirectX / DXILFinalizeLinkage.cpp
blob91ac758150fb4ca1cf81571a48cad7a44a1d65a5
1 //===- DXILFinalizeLinkage.cpp - Finalize linkage of functions ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "DXILFinalizeLinkage.h"
10 #include "DirectX.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/GlobalValue.h"
13 #include "llvm/IR/Metadata.h"
14 #include "llvm/IR/Module.h"
16 #define DEBUG_TYPE "dxil-finalize-linkage"
18 using namespace llvm;
20 static bool finalizeLinkage(Module &M) {
21 SmallPtrSet<Function *, 8> Funcs;
23 // Collect non-entry and non-exported functions to set to internal linkage.
24 for (Function &EF : M.functions()) {
25 if (EF.isIntrinsic())
26 continue;
27 if (EF.hasFnAttribute("hlsl.shader") || EF.hasFnAttribute("hlsl.export"))
28 continue;
29 Funcs.insert(&EF);
32 for (Function *F : Funcs) {
33 if (F->getLinkage() == GlobalValue::ExternalLinkage)
34 F->setLinkage(GlobalValue::InternalLinkage);
35 if (F->isDefTriviallyDead())
36 M.getFunctionList().erase(F);
39 return false;
42 PreservedAnalyses DXILFinalizeLinkage::run(Module &M,
43 ModuleAnalysisManager &AM) {
44 if (finalizeLinkage(M))
45 return PreservedAnalyses::none();
46 return PreservedAnalyses::all();
49 bool DXILFinalizeLinkageLegacy::runOnModule(Module &M) {
50 return finalizeLinkage(M);
53 char DXILFinalizeLinkageLegacy::ID = 0;
55 INITIALIZE_PASS_BEGIN(DXILFinalizeLinkageLegacy, DEBUG_TYPE,
56 "DXIL Finalize Linkage", false, false)
57 INITIALIZE_PASS_END(DXILFinalizeLinkageLegacy, DEBUG_TYPE,
58 "DXIL Finalize Linkage", false, false)
60 ModulePass *llvm::createDXILFinalizeLinkageLegacyPass() {
61 return new DXILFinalizeLinkageLegacy();