1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome_frame/dll_redirector.h"
9 #include <atlsecurity.h>
12 #include "base/file_path.h"
13 #include "base/file_version_info.h"
14 #include "base/logging.h"
15 #include "base/path_service.h"
16 #include "base/shared_memory.h"
17 #include "base/string_util.h"
18 #include "base/utf_string_conversions.h"
19 #include "base/version.h"
20 #include "base/win/windows_version.h"
21 #include "chrome_frame/utils.h"
23 const wchar_t kSharedMemoryName
[] = L
"ChromeFrameVersionBeacon_";
24 const uint32 kSharedMemorySize
= 128;
25 const uint32 kSharedMemoryLockTimeoutMs
= 1000;
28 DllRedirector::DllRedirector() : first_module_handle_(NULL
) {
29 // TODO(robertshield): Allow for overrides to be taken from the environment.
30 std::wstring
beacon_name(kSharedMemoryName
);
31 beacon_name
+= GetHostProcessName(false);
32 shared_memory_
.reset(new base::SharedMemory(beacon_name
));
33 shared_memory_name_
= WideToUTF8(beacon_name
);
36 DllRedirector::DllRedirector(const char* shared_memory_name
)
37 : shared_memory_name_(shared_memory_name
), first_module_handle_(NULL
) {
38 shared_memory_
.reset(new base::SharedMemory(ASCIIToWide(shared_memory_name
)));
41 DllRedirector::~DllRedirector() {
42 if (first_module_handle_
) {
43 if (first_module_handle_
!= reinterpret_cast<HMODULE
>(&__ImageBase
)) {
44 FreeLibrary(first_module_handle_
);
46 NOTREACHED() << "Error, DllRedirector attempting to free self.";
49 first_module_handle_
= NULL
;
51 UnregisterAsFirstCFModule();
55 DllRedirector
* DllRedirector::GetInstance() {
56 return Singleton
<DllRedirector
>::get();
59 bool DllRedirector::BuildSecurityAttributesForLock(
60 CSecurityAttributes
* sec_attr
) {
62 if (base::win::GetVersion() < base::win::VERSION_VISTA
) {
63 // Don't bother with changing ACLs on pre-vista.
69 // Fill out the rest of the security descriptor from the process token.
71 if (token
.GetProcessToken(TOKEN_QUERY
)) {
72 CSecurityDesc security_desc
;
73 // Set the SACL from an SDDL string that allows access to low-integrity
74 // processes. See http://msdn.microsoft.com/en-us/library/bb625958.aspx.
75 if (security_desc
.FromString(L
"S:(ML;;NW;;;LW)")) {
77 if (token
.GetOwner(&sid_owner
)) {
78 security_desc
.SetOwner(sid_owner
);
80 NOTREACHED() << "Could not get owner.";
83 if (token
.GetPrimaryGroup(&sid_group
)) {
84 security_desc
.SetGroup(sid_group
);
86 NOTREACHED() << "Could not get group.";
89 if (token
.GetDefaultDacl(&dacl
)) {
90 // Add an access control entry mask for the current user.
91 // This is what grants this user access from lower integrity levels.
93 if (token
.GetUser(&sid_user
)) {
94 success
= dacl
.AddAllowedAce(sid_user
, MUTEX_ALL_ACCESS
);
95 security_desc
.SetDacl(dacl
);
96 sec_attr
->Set(security_desc
);
105 bool DllRedirector::SetFileMappingToReadOnly(base::SharedMemoryHandle mapping
) {
106 bool success
= false;
109 if (token
.GetProcessToken(TOKEN_QUERY
)) {
111 if (token
.GetUser(&sid_user
)) {
113 dacl
.AddAllowedAce(sid_user
, STANDARD_RIGHTS_READ
| FILE_MAP_READ
);
114 success
= AtlSetDacl(mapping
, SE_KERNEL_OBJECT
, dacl
);
122 bool DllRedirector::RegisterAsFirstCFModule() {
123 DCHECK(first_module_handle_
== NULL
);
125 // Build our own file version outside of the lock:
126 scoped_ptr
<Version
> our_version(GetCurrentModuleVersion());
128 // We sadly can't use the autolock here since we want to have a timeout.
129 // Be careful not to return while holding the lock. Also, attempt to do as
130 // little as possible while under this lock.
132 bool lock_acquired
= false;
133 CSecurityAttributes sec_attr
;
134 if (base::win::GetVersion() >= base::win::VERSION_VISTA
&&
135 BuildSecurityAttributesForLock(&sec_attr
)) {
136 // On vista and above, we need to explicitly allow low integrity access
137 // to our objects. On XP, we don't bother.
138 lock_acquired
= shared_memory_
->Lock(kSharedMemoryLockTimeoutMs
, &sec_attr
);
140 lock_acquired
= shared_memory_
->Lock(kSharedMemoryLockTimeoutMs
, NULL
);
143 if (!lock_acquired
) {
144 // We couldn't get the lock in a reasonable amount of time, so fall
145 // back to loading our current version. We return true to indicate that the
146 // caller should not attempt to delegate to an already loaded version.
147 dll_version_
.swap(our_version
);
151 bool created_beacon
= true;
152 bool result
= shared_memory_
->CreateNamed(shared_memory_name_
.c_str(),
153 false, // open_existing
157 // We created the beacon, now we need to mutate the security attributes
158 // on the shared memory to allow read-only access and let low-integrity
159 // processes open it. This will fail on FAT32 file systems.
160 if (!SetFileMappingToReadOnly(shared_memory_
->handle())) {
161 DLOG(ERROR
) << "Failed to set file mapping permissions.";
164 created_beacon
= false;
166 // We failed to create the shared memory segment, suggesting it may already
167 // exist: try to create it read-only.
168 result
= shared_memory_
->Open(shared_memory_name_
.c_str(),
169 true /* read_only */);
173 // Map in the whole thing.
174 result
= shared_memory_
->Map(0);
175 DCHECK(shared_memory_
->memory());
178 // Either write our own version number or read it in if it was already
179 // present in the shared memory section.
180 if (created_beacon
) {
181 dll_version_
.swap(our_version
);
183 lstrcpynA(reinterpret_cast<char*>(shared_memory_
->memory()),
184 dll_version_
->GetString().c_str(),
185 std::min(kSharedMemorySize
,
186 dll_version_
->GetString().length() + 1));
188 char buffer
[kSharedMemorySize
] = {0};
189 memcpy(buffer
, shared_memory_
->memory(), kSharedMemorySize
- 1);
190 dll_version_
.reset(Version::GetVersionFromString(buffer
));
192 if (!dll_version_
.get() || dll_version_
->Equals(*our_version
.get())) {
193 // If we either couldn't parse a valid version out of the shared
194 // memory or we did parse a version and it is the same as our own,
195 // then pretend we're first in to avoid trying to load any other DLLs.
196 dll_version_
.reset(our_version
.release());
197 created_beacon
= true;
201 NOTREACHED() << "Failed to map in version beacon.";
204 NOTREACHED() << "Could not create file mapping for version beacon, gle: "
209 shared_memory_
->Unlock();
211 return created_beacon
;
214 void DllRedirector::UnregisterAsFirstCFModule() {
215 if (base::SharedMemory::IsHandleValid(shared_memory_
->handle())) {
216 bool lock_acquired
= shared_memory_
->Lock(kSharedMemoryLockTimeoutMs
, NULL
);
218 // Free our handles. The last closed handle SHOULD result in it being
220 shared_memory_
->Close();
221 shared_memory_
->Unlock();
226 LPFNGETCLASSOBJECT
DllRedirector::GetDllGetClassObjectPtr() {
227 HMODULE first_module_handle
= GetFirstModule();
229 LPFNGETCLASSOBJECT proc_ptr
= NULL
;
230 if (first_module_handle
) {
231 proc_ptr
= reinterpret_cast<LPFNGETCLASSOBJECT
>(
232 GetProcAddress(first_module_handle
, "DllGetClassObject"));
233 DPLOG_IF(ERROR
, !proc_ptr
) << "DllRedirector: Could not get address of "
234 "DllGetClassObject from first loaded module.";
240 Version
* DllRedirector::GetCurrentModuleVersion() {
241 scoped_ptr
<FileVersionInfo
> file_version_info(
242 FileVersionInfo::CreateFileVersionInfoForCurrentModule());
243 DCHECK(file_version_info
.get());
245 Version
* current_version
= NULL
;
246 if (file_version_info
.get()) {
247 current_version
= Version::GetVersionFromString(
248 WideToASCII(file_version_info
->file_version()));
249 DCHECK(current_version
);
252 return current_version
;
255 HMODULE
DllRedirector::GetFirstModule() {
256 DCHECK(dll_version_
.get())
257 << "Error: Did you call RegisterAsFirstCFModule() first?";
259 if (first_module_handle_
== NULL
) {
260 first_module_handle_
= LoadVersionedModule(dll_version_
.get());
263 if (first_module_handle_
== reinterpret_cast<HMODULE
>(&__ImageBase
)) {
264 NOTREACHED() << "Should not be loading own version.";
265 first_module_handle_
= NULL
;
268 return first_module_handle_
;
271 HMODULE
DllRedirector::LoadVersionedModule(Version
* version
) {
274 FilePath module_path
;
275 PathService::Get(base::FILE_MODULE
, &module_path
);
276 DCHECK(!module_path
.empty());
278 // For a module located in
279 // Foo\XXXXXXXXX\<module>.dll, load
280 // Foo\<version>\<module>.dll:
281 FilePath module_name
= module_path
.BaseName();
282 module_path
= module_path
.DirName()
284 .Append(ASCIIToWide(version
->GetString()))
285 .Append(module_name
);
287 HMODULE hmodule
= LoadLibrary(module_path
.value().c_str());
288 if (hmodule
== NULL
) {
289 DPLOG(ERROR
) << "Could not load reported module version "
290 << version
->GetString();