Merge Chromium + Blink git repositories
[chromium-blink-merge.git] / remoting / host / win / chromoting_module.cc
blob 566418f28ed6c5e7bfcb4cc354b5e42c508ca55b
1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/host/win/chromoting_module.h"
7 #include "base/lazy_instance.h"
8 #include "base/logging.h"
9 #include "base/message_loop/message_loop.h"
10 #include "base/run_loop.h"
11 #include "base/strings/utf_string_conversions.h"
12 #include "base/win/scoped_handle.h"
13 #include "base/win/windows_version.h"
14 #include "remoting/base/auto_thread_task_runner.h"
15 #include "remoting/base/typed_buffer.h"
16 #include "remoting/host/host_exit_codes.h"
17 #include "remoting/host/win/rdp_desktop_session.h"
19 namespace remoting {
21 namespace {
23 // Holds a reference to the task runner used by the module.
24 base::LazyInstance<scoped_refptr<AutoThreadTaskRunner> > g_module_task_runner =
25 LAZY_INSTANCE_INITIALIZER;
27 // Lowers the process integrity level such that it does not exceed |max_level|.
28 // |max_level| is expected to be one of SECURITY_MANDATORY_XXX constants.
29 bool LowerProcessIntegrityLevel(DWORD max_level) {
30 HANDLE temp_handle;
31 if (!OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY | TOKEN_WRITE,
32 &temp_handle)) {
33 PLOG(ERROR) << "OpenProcessToken() failed";
34 return false;
36 base::win::ScopedHandle token(temp_handle);
38 TypedBuffer<TOKEN_MANDATORY_LABEL> mandatory_label;
39 DWORD length = 0;
41 // Get the size of the buffer needed to hold the mandatory label.
42 BOOL result = GetTokenInformation(token.Get(), TokenIntegrityLevel,
43 mandatory_label.get(), length, &length);
44 if (!result && GetLastError() == ERROR_INSUFFICIENT_BUFFER) {
45 // Allocate a buffer that is large enough.
46 TypedBuffer<TOKEN_MANDATORY_LABEL> buffer(length);
47 mandatory_label.Swap(buffer);
49 // Get the the mandatory label.
50 result = GetTokenInformation(token.Get(), TokenIntegrityLevel,
51 mandatory_label.get(), length, &length);
53 if (!result) {
54 PLOG(ERROR) << "Failed to get the mandatory label";
55 return false;
58 // Read the current integrity level.
59 DWORD sub_authority_count =
60 *GetSidSubAuthorityCount(mandatory_label->Label.Sid);
61 DWORD* current_level = GetSidSubAuthority(mandatory_label->Label.Sid,
62 sub_authority_count - 1);
64 // Set the integrity level to |max_level| if needed.
65 if (*current_level > max_level) {
66 *current_level = max_level;
67 if (!SetTokenInformation(token.Get(), TokenIntegrityLevel,
68 mandatory_label.get(), length)) {
69 PLOG(ERROR) << "Failed to set the mandatory label";
70 return false;
74 return true;
77 } // namespace
79 ChromotingModule::ChromotingModule(
80 ATL::_ATL_OBJMAP_ENTRY* classes,
81 ATL::_ATL_OBJMAP_ENTRY* classes_end)
82 : classes_(classes),
83 classes_end_(classes_end) {
84 // Don't do anything if COM initialization failed.
85 if (!com_initializer_.succeeded())
86 return;
88 ATL::_AtlComModule.ExecuteObjectMain(true);
91 ChromotingModule::~ChromotingModule() {
92 // Don't do anything if COM initialization failed.
93 if (!com_initializer_.succeeded())
94 return;
96 Term();
97 ATL::_AtlComModule.ExecuteObjectMain(false);
100 // static
101 scoped_refptr<AutoThreadTaskRunner> ChromotingModule::task_runner() {
102 return g_module_task_runner.Get();
105 bool ChromotingModule::Run() {
106 // Don't do anything if COM initialization failed.
107 if (!com_initializer_.succeeded())
108 return false;
110 // Register class objects.
111 HRESULT result = RegisterClassObjects(CLSCTX_LOCAL_SERVER,
112 REGCLS_MULTIPLEUSE | REGCLS_SUSPENDED);
113 if (FAILED(result)) {
114 LOG(ERROR) << "Failed to register class objects, result=0x"
115 << std::hex << result << std::dec << ".";
116 return false;
119 // Arrange to run |message_loop| until no components depend on it.
120 base::MessageLoopForUI message_loop;
121 base::RunLoop run_loop;
122 g_module_task_runner.Get() = new AutoThreadTaskRunner(
123 message_loop.task_runner(), run_loop.QuitClosure());
125 // Start accepting activations.
126 result = CoResumeClassObjects();
127 if (FAILED(result)) {
128 LOG(ERROR) << "CoResumeClassObjects() failed, result=0x"
129 << std::hex << result << std::dec << ".";
130 return false;
133 // Run the loop until the module lock counter reaches zero.
134 run_loop.Run();
136 // Unregister class objects.
137 result = RevokeClassObjects();
138 if (FAILED(result)) {
139 LOG(ERROR) << "Failed to unregister class objects, result=0x"
140 << std::hex << result << std::dec << ".";
141 return false;
144 return true;
147 LONG ChromotingModule::Unlock() {
148 LONG count = ATL::CAtlModuleT<ChromotingModule>::Unlock();
150 if (!count) {
151 // Stop accepting activations.
152 HRESULT hr = CoSuspendClassObjects();
153 CHECK(SUCCEEDED(hr));
155 // Release the message loop reference, causing the message loop to exit.
156 g_module_task_runner.Get() = nullptr;
159 return count;
162 HRESULT ChromotingModule::RegisterClassObjects(DWORD class_context,
163 DWORD flags) {
164 for (ATL::_ATL_OBJMAP_ENTRY* i = classes_; i != classes_end_; ++i) {
165 HRESULT result = i->RegisterClassObject(class_context, flags);
166 if (FAILED(result))
167 return result;
170 return S_OK;
173 HRESULT ChromotingModule::RevokeClassObjects() {
174 for (ATL::_ATL_OBJMAP_ENTRY* i = classes_; i != classes_end_; ++i) {
175 HRESULT result = i->RevokeClassObject();
176 if (FAILED(result))
177 return result;
180 return S_OK;
183 // RdpClient entry point.
184 int RdpDesktopSessionMain() {
185 // Lower the integrity level to medium, which is the lowest level at which
186 // the RDP ActiveX control can run.
187 if (base::win::GetVersion() >= base::win::VERSION_VISTA) {
188 if (!LowerProcessIntegrityLevel(SECURITY_MANDATORY_MEDIUM_RID))
189 return kInitializationFailed;
192 ATL::_ATL_OBJMAP_ENTRY rdp_client_entry[] = {
193 OBJECT_ENTRY(__uuidof(RdpDesktopSession), RdpDesktopSession)
196 ChromotingModule module(rdp_client_entry, rdp_client_entry + 1);
197 return module.Run() ? kSuccessExitCode : kInitializationFailed;
200 } // namespace remoting