Use a smart pointer to auto-free some WinAPI memory
[openal-soft.git] / router / router.cpp
blob e9dfb9a08853834b0ba2cab372e534adae19deae
#include "config.h"

#include "router.h"

#include <algorithm>
#include <array>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <string_view>
#include <vector>

#include "AL/alc.h"
#include "AL/al.h"

#include "albit.h"
#include "alstring.h"
#include "opthelpers.h"
#include "strutils.h"

#include "version.h"
26 eLogLevel LogLevel{eLogLevel::Error};
27 gsl::owner<std::FILE*> LogFile;
29 namespace {
31 std::vector<std::wstring> gAcceptList;
32 std::vector<std::wstring> gRejectList;
35 void AddModule(HMODULE module, const std::wstring_view name)
37 for(auto &drv : DriverList)
39 if(drv->Module == module)
41 TRACE("Skipping already-loaded module %p\n", decltype(std::declval<void*>()){module});
42 FreeLibrary(module);
43 return;
45 if(drv->Name == name)
47 TRACE("Skipping similarly-named module %.*ls\n", al::sizei(name), name.data());
48 FreeLibrary(module);
49 return;
52 if(!gAcceptList.empty())
54 auto iter = std::find_if(gAcceptList.cbegin(), gAcceptList.cend(),
55 [name](const std::wstring_view accept)
56 { return al::case_compare(name, accept) == 0; });
57 if(iter == gAcceptList.cend())
59 TRACE("%.*ls not found in ALROUTER_ACCEPT, skipping\n", al::sizei(name), name.data());
60 FreeLibrary(module);
61 return;
64 if(!gRejectList.empty())
66 auto iter = std::find_if(gRejectList.cbegin(), gRejectList.cend(),
67 [name](const std::wstring_view accept)
68 { return al::case_compare(name, accept) == 0; });
69 if(iter != gRejectList.cend())
71 TRACE("%.*ls found in ALROUTER_REJECT, skipping\n", al::sizei(name), name.data());
72 FreeLibrary(module);
73 return;
77 DriverIface &newdrv = *DriverList.emplace_back(std::make_unique<DriverIface>(name, module));
79 /* Load required functions. */
80 bool loadok{true};
81 auto do_load = [module,name](auto &func, const char *fname) -> bool
83 using func_t = std::remove_reference_t<decltype(func)>;
84 auto ptr = GetProcAddress(module, fname);
85 if(!ptr)
87 ERR("Failed to find entry point for %s in %.*ls\n", fname, al::sizei(name),
88 name.data());
89 return false;
92 func = al::bit_cast<func_t>(ptr);
93 return true;
95 #define LOAD_PROC(x) loadok &= do_load(newdrv.x, #x)
96 LOAD_PROC(alcCreateContext);
97 LOAD_PROC(alcMakeContextCurrent);
98 LOAD_PROC(alcProcessContext);
99 LOAD_PROC(alcSuspendContext);
100 LOAD_PROC(alcDestroyContext);
101 LOAD_PROC(alcGetCurrentContext);
102 LOAD_PROC(alcGetContextsDevice);
103 LOAD_PROC(alcOpenDevice);
104 LOAD_PROC(alcCloseDevice);
105 LOAD_PROC(alcGetError);
106 LOAD_PROC(alcIsExtensionPresent);
107 LOAD_PROC(alcGetProcAddress);
108 LOAD_PROC(alcGetEnumValue);
109 LOAD_PROC(alcGetString);
110 LOAD_PROC(alcGetIntegerv);
111 LOAD_PROC(alcCaptureOpenDevice);
112 LOAD_PROC(alcCaptureCloseDevice);
113 LOAD_PROC(alcCaptureStart);
114 LOAD_PROC(alcCaptureStop);
115 LOAD_PROC(alcCaptureSamples);
117 LOAD_PROC(alEnable);
118 LOAD_PROC(alDisable);
119 LOAD_PROC(alIsEnabled);
120 LOAD_PROC(alGetString);
121 LOAD_PROC(alGetBooleanv);
122 LOAD_PROC(alGetIntegerv);
123 LOAD_PROC(alGetFloatv);
124 LOAD_PROC(alGetDoublev);
125 LOAD_PROC(alGetBoolean);
126 LOAD_PROC(alGetInteger);
127 LOAD_PROC(alGetFloat);
128 LOAD_PROC(alGetDouble);
129 LOAD_PROC(alGetError);
130 LOAD_PROC(alIsExtensionPresent);
131 LOAD_PROC(alGetProcAddress);
132 LOAD_PROC(alGetEnumValue);
133 LOAD_PROC(alListenerf);
134 LOAD_PROC(alListener3f);
135 LOAD_PROC(alListenerfv);
136 LOAD_PROC(alListeneri);
137 LOAD_PROC(alListener3i);
138 LOAD_PROC(alListeneriv);
139 LOAD_PROC(alGetListenerf);
140 LOAD_PROC(alGetListener3f);
141 LOAD_PROC(alGetListenerfv);
142 LOAD_PROC(alGetListeneri);
143 LOAD_PROC(alGetListener3i);
144 LOAD_PROC(alGetListeneriv);
145 LOAD_PROC(alGenSources);
146 LOAD_PROC(alDeleteSources);
147 LOAD_PROC(alIsSource);
148 LOAD_PROC(alSourcef);
149 LOAD_PROC(alSource3f);
150 LOAD_PROC(alSourcefv);
151 LOAD_PROC(alSourcei);
152 LOAD_PROC(alSource3i);
153 LOAD_PROC(alSourceiv);
154 LOAD_PROC(alGetSourcef);
155 LOAD_PROC(alGetSource3f);
156 LOAD_PROC(alGetSourcefv);
157 LOAD_PROC(alGetSourcei);
158 LOAD_PROC(alGetSource3i);
159 LOAD_PROC(alGetSourceiv);
160 LOAD_PROC(alSourcePlayv);
161 LOAD_PROC(alSourceStopv);
162 LOAD_PROC(alSourceRewindv);
163 LOAD_PROC(alSourcePausev);
164 LOAD_PROC(alSourcePlay);
165 LOAD_PROC(alSourceStop);
166 LOAD_PROC(alSourceRewind);
167 LOAD_PROC(alSourcePause);
168 LOAD_PROC(alSourceQueueBuffers);
169 LOAD_PROC(alSourceUnqueueBuffers);
170 LOAD_PROC(alGenBuffers);
171 LOAD_PROC(alDeleteBuffers);
172 LOAD_PROC(alIsBuffer);
173 LOAD_PROC(alBufferData);
174 LOAD_PROC(alDopplerFactor);
175 LOAD_PROC(alDopplerVelocity);
176 LOAD_PROC(alSpeedOfSound);
177 LOAD_PROC(alDistanceModel);
178 #undef LOAD_PROC
179 if(loadok)
181 std::array<ALCint,2> alc_ver{0, 0};
182 newdrv.alcGetIntegerv(nullptr, ALC_MAJOR_VERSION, 1, &alc_ver[0]);
183 newdrv.alcGetIntegerv(nullptr, ALC_MINOR_VERSION, 1, &alc_ver[1]);
184 if(newdrv.alcGetError(nullptr) == ALC_NO_ERROR)
185 newdrv.ALCVer = MAKE_ALC_VER(alc_ver[0], alc_ver[1]);
186 else
188 WARN("Failed to query ALC version for %.*ls, assuming 1.0\n", al::sizei(name),
189 name.data());
190 newdrv.ALCVer = MAKE_ALC_VER(1, 0);
193 auto do_load2 = [module,name](auto &func, const char *fname) -> void
195 using func_t = std::remove_reference_t<decltype(func)>;
196 auto ptr = GetProcAddress(module, fname);
197 if(!ptr)
198 WARN("Failed to find optional entry point for %s in %.*ls\n", fname,
199 al::sizei(name), name.data());
200 else
201 func = al::bit_cast<func_t>(ptr);
203 #define LOAD_PROC(x) do_load2(newdrv.x, #x)
204 LOAD_PROC(alBufferf);
205 LOAD_PROC(alBuffer3f);
206 LOAD_PROC(alBufferfv);
207 LOAD_PROC(alBufferi);
208 LOAD_PROC(alBuffer3i);
209 LOAD_PROC(alBufferiv);
210 LOAD_PROC(alGetBufferf);
211 LOAD_PROC(alGetBuffer3f);
212 LOAD_PROC(alGetBufferfv);
213 LOAD_PROC(alGetBufferi);
214 LOAD_PROC(alGetBuffer3i);
215 LOAD_PROC(alGetBufferiv);
216 #undef LOAD_PROC
218 auto do_load3 = [name,&newdrv](auto &func, const char *fname) -> bool
220 using func_t = std::remove_reference_t<decltype(func)>;
221 auto ptr = newdrv.alcGetProcAddress(nullptr, fname);
222 if(!ptr)
224 ERR("Failed to find entry point for %s in %.*ls\n", fname, al::sizei(name),
225 name.data());
226 return false;
229 func = reinterpret_cast<func_t>(ptr);
230 return true;
232 #define LOAD_PROC(x) loadok &= do_load3(newdrv.x, #x)
233 if(newdrv.alcIsExtensionPresent(nullptr, "ALC_EXT_thread_local_context"))
235 LOAD_PROC(alcSetThreadContext);
236 LOAD_PROC(alcGetThreadContext);
238 #undef LOAD_PROC
241 if(!loadok)
243 DriverList.pop_back();
244 return;
246 TRACE("Loaded module %p, %.*ls, ALC %d.%d\n", decltype(std::declval<void*>()){module},
247 al::sizei(name), name.data(), newdrv.ALCVer>>8, newdrv.ALCVer&255);
250 void SearchDrivers(const std::wstring_view path)
252 TRACE("Searching for drivers in %.*ls...\n", al::sizei(path), path.data());
253 std::wstring srchPath{path};
254 srchPath += L"\\*oal.dll";
256 WIN32_FIND_DATAW fdata{};
257 HANDLE srchHdl{FindFirstFileW(srchPath.c_str(), &fdata)};
258 if(srchHdl == INVALID_HANDLE_VALUE) return;
260 do {
261 srchPath = path;
262 srchPath += L"\\";
263 srchPath += std::data(fdata.cFileName);
264 TRACE("Found %ls\n", srchPath.c_str());
266 HMODULE mod{LoadLibraryW(srchPath.c_str())};
267 if(!mod)
268 WARN("Could not load %ls\n", srchPath.c_str());
269 else
270 AddModule(mod, std::data(fdata.cFileName));
271 } while(FindNextFileW(srchHdl, &fdata));
272 FindClose(srchHdl);
275 bool GetLoadedModuleDirectory(const WCHAR *name, std::wstring *moddir)
277 HMODULE module{nullptr};
279 if(name)
281 module = GetModuleHandleW(name);
282 if(!module) return false;
285 moddir->assign(256, '\0');
286 DWORD res{GetModuleFileNameW(module, moddir->data(), static_cast<DWORD>(moddir->size()))};
287 if(res >= moddir->size())
289 do {
290 moddir->append(256, '\0');
291 res = GetModuleFileNameW(module, moddir->data(), static_cast<DWORD>(moddir->size()));
292 } while(res >= moddir->size());
294 moddir->resize(res);
296 auto sep0 = moddir->rfind('/');
297 auto sep1 = moddir->rfind('\\');
298 if(sep0 < moddir->size() && sep1 < moddir->size())
299 moddir->resize(std::max(sep0, sep1));
300 else if(sep0 < moddir->size())
301 moddir->resize(sep0);
302 else if(sep1 < moddir->size())
303 moddir->resize(sep1);
304 else
305 moddir->resize(0);
307 return !moddir->empty();
310 void LoadDriverList()
312 if(auto list = al::getenv(L"ALROUTER_ACCEPT"))
314 std::wstring_view namelist{*list};
315 while(!namelist.empty())
317 auto seppos = namelist.find(',');
318 if(seppos > 0)
319 gAcceptList.emplace_back(namelist.substr(0, seppos));
320 if(seppos < namelist.size())
321 namelist.remove_prefix(seppos+1);
322 else
323 namelist.remove_prefix(namelist.size());
326 if(auto list = al::getenv(L"ALROUTER_REJECT"))
328 std::wstring_view namelist{*list};
329 while(!namelist.empty())
331 auto seppos = namelist.find(',');
332 if(seppos > 0)
333 gRejectList.emplace_back(namelist.substr(0, seppos));
334 if(seppos < namelist.size())
335 namelist.remove_prefix(seppos+1);
336 else
337 namelist.remove_prefix(namelist.size());
341 std::wstring dll_path;
342 if(GetLoadedModuleDirectory(L"OpenAL32.dll", &dll_path))
343 TRACE("Got DLL path %ls\n", dll_path.c_str());
345 std::wstring cwd_path;
346 if(DWORD pathlen{GetCurrentDirectoryW(0, nullptr)})
348 do {
349 cwd_path.resize(pathlen);
350 pathlen = GetCurrentDirectoryW(pathlen, cwd_path.data());
351 } while(pathlen >= cwd_path.size());
352 cwd_path.resize(pathlen);
354 if(!cwd_path.empty() && (cwd_path.back() == '\\' || cwd_path.back() == '/'))
355 cwd_path.pop_back();
356 if(!cwd_path.empty())
357 TRACE("Got current working directory %ls\n", cwd_path.c_str());
359 std::wstring proc_path;
360 if(GetLoadedModuleDirectory(nullptr, &proc_path))
361 TRACE("Got proc path %ls\n", proc_path.c_str());
363 std::wstring sys_path;
364 if(UINT pathlen{GetSystemDirectoryW(nullptr, 0)})
366 do {
367 sys_path.resize(pathlen);
368 pathlen = GetSystemDirectoryW(sys_path.data(), pathlen);
369 } while(pathlen >= sys_path.size());
370 sys_path.resize(pathlen);
372 if(!sys_path.empty() && (sys_path.back() == '\\' || sys_path.back() == '/'))
373 sys_path.pop_back();
374 if(!sys_path.empty())
375 TRACE("Got system path %ls\n", sys_path.c_str());
377 /* Don't search the DLL's path if it is the same as the current working
378 * directory, app's path, or system path (don't want to do duplicate
379 * searches, or increase the priority of the app or system path).
381 if(!dll_path.empty() &&
382 (cwd_path.empty() || dll_path != cwd_path) &&
383 (proc_path.empty() || dll_path != proc_path) &&
384 (sys_path.empty() || dll_path != sys_path))
385 SearchDrivers(dll_path);
386 if(!cwd_path.empty() &&
387 (proc_path.empty() || cwd_path != proc_path) &&
388 (sys_path.empty() || cwd_path != sys_path))
389 SearchDrivers(cwd_path);
390 if(!proc_path.empty() && (sys_path.empty() || proc_path != sys_path))
391 SearchDrivers(proc_path);
392 if(!sys_path.empty())
393 SearchDrivers(sys_path);
396 } // namespace
398 BOOL APIENTRY DllMain(HINSTANCE, DWORD reason, void*)
400 switch(reason)
402 case DLL_PROCESS_ATTACH:
403 if(auto logfname = al::getenv("ALROUTER_LOGFILE"))
405 gsl::owner<std::FILE*> f{fopen(logfname->c_str(), "w")};
406 if(f == nullptr)
407 ERR("Could not open log file: %s\n", logfname->c_str());
408 else
409 LogFile = f;
411 if(auto loglev = al::getenv("ALROUTER_LOGLEVEL"))
413 char *end = nullptr;
414 long l{strtol(loglev->c_str(), &end, 0)};
415 if(!end || *end != '\0')
416 ERR("Invalid log level value: %s\n", loglev->c_str());
417 else if(l < al::to_underlying(eLogLevel::None)
418 || l > al::to_underlying(eLogLevel::Trace))
419 ERR("Log level out of range: %s\n", loglev->c_str());
420 else
421 LogLevel = static_cast<eLogLevel>(l);
423 TRACE("Initializing router v0.1-%s %s\n", ALSOFT_GIT_COMMIT_HASH, ALSOFT_GIT_BRANCH);
424 LoadDriverList();
426 break;
428 case DLL_THREAD_ATTACH:
429 break;
430 case DLL_THREAD_DETACH:
431 break;
433 case DLL_PROCESS_DETACH:
434 DriverList.clear();
436 if(LogFile)
437 fclose(LogFile);
438 LogFile = nullptr;
440 break;
442 return TRUE;