1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/policy/cloud/test_request_interceptor.h"
11 #include "base/bind.h"
12 #include "base/bind_helpers.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/run_loop.h"
15 #include "base/sequenced_task_runner.h"
16 #include "base/thread_task_runner_handle.h"
17 #include "content/public/browser/browser_thread.h"
18 #include "net/base/net_errors.h"
19 #include "net/base/upload_bytes_element_reader.h"
20 #include "net/base/upload_data_stream.h"
21 #include "net/base/upload_element_reader.h"
22 #include "net/test/url_request/url_request_mock_http_job.h"
23 #include "net/url_request/url_request_error_job.h"
24 #include "net/url_request/url_request_filter.h"
25 #include "net/url_request/url_request_interceptor.h"
26 #include "net/url_request/url_request_test_job.h"
29 namespace em
= enterprise_management
;
35 // Helper callback for jobs that should fail with a network |error|.
36 net::URLRequestJob
* ErrorJobCallback(int error
,
37 net::URLRequest
* request
,
38 net::NetworkDelegate
* network_delegate
) {
39 return new net::URLRequestErrorJob(request
, network_delegate
, error
);
42 // Helper callback for jobs that should fail with a 400 HTTP error.
43 net::URLRequestJob
* BadRequestJobCallback(
44 net::URLRequest
* request
,
45 net::NetworkDelegate
* network_delegate
) {
46 static const char kBadHeaders
[] =
47 "HTTP/1.1 400 Bad request\0"
48 "Content-type: application/protobuf\0"
50 std::string
headers(kBadHeaders
, arraysize(kBadHeaders
));
51 return new net::URLRequestTestJob(
52 request
, network_delegate
, headers
, std::string(), true);
55 net::URLRequestJob
* FileJobCallback(const base::FilePath
& file_path
,
56 net::URLRequest
* request
,
57 net::NetworkDelegate
* network_delegate
) {
58 return new net::URLRequestMockHTTPJob(
62 content::BrowserThread::GetBlockingPool()
63 ->GetTaskRunnerWithShutdownBehavior(
64 base::SequencedWorkerPool::SKIP_ON_SHUTDOWN
));
67 // Parses the upload data in |request| into |request_msg|, and validates the
68 // request. The query string in the URL must contain the |expected_type| for
69 // the "request" parameter. Returns true if all checks succeeded, and the
70 // request data has been parsed into |request_msg|.
71 bool ValidRequest(net::URLRequest
* request
,
72 const std::string
& expected_type
,
73 em::DeviceManagementRequest
* request_msg
) {
74 if (request
->method() != "POST")
76 std::string spec
= request
->url().spec();
77 if (spec
.find("request=" + expected_type
) == std::string::npos
)
80 // This assumes that the payload data was set from a single string. In that
81 // case the UploadDataStream has a single UploadBytesElementReader with the
83 const net::UploadDataStream
* stream
= request
->get_upload();
86 const ScopedVector
<net::UploadElementReader
>* readers
=
87 stream
->GetElementReaders();
88 if (!readers
|| readers
->size() != 1u)
90 const net::UploadBytesElementReader
* reader
= (*readers
)[0]->AsBytesReader();
93 std::string
data(reader
->bytes(), reader
->length());
94 if (!request_msg
->ParseFromString(data
))
100 // Helper callback for register jobs that should succeed. Validates the request
101 // parameters and returns an appropriate response job. If |expect_reregister|
102 // is true then the reregister flag must be set in the DeviceRegisterRequest
104 net::URLRequestJob
* RegisterJobCallback(
105 em::DeviceRegisterRequest::Type expected_type
,
106 bool expect_reregister
,
107 net::URLRequest
* request
,
108 net::NetworkDelegate
* network_delegate
) {
109 em::DeviceManagementRequest request_msg
;
110 if (!ValidRequest(request
, "register", &request_msg
))
111 return BadRequestJobCallback(request
, network_delegate
);
113 if (!request_msg
.has_register_request() ||
114 request_msg
.has_unregister_request() ||
115 request_msg
.has_policy_request() ||
116 request_msg
.has_device_status_report_request() ||
117 request_msg
.has_session_status_report_request() ||
118 request_msg
.has_auto_enrollment_request()) {
119 return BadRequestJobCallback(request
, network_delegate
);
122 const em::DeviceRegisterRequest
& register_request
=
123 request_msg
.register_request();
124 if (expect_reregister
&&
125 (!register_request
.has_reregister() || !register_request
.reregister())) {
126 return BadRequestJobCallback(request
, network_delegate
);
127 } else if (!expect_reregister
&&
128 register_request
.has_reregister() &&
129 register_request
.reregister()) {
130 return BadRequestJobCallback(request
, network_delegate
);
133 if (!register_request
.has_type() || register_request
.type() != expected_type
)
134 return BadRequestJobCallback(request
, network_delegate
);
136 em::DeviceManagementResponse response
;
137 em::DeviceRegisterResponse
* register_response
=
138 response
.mutable_register_response();
139 register_response
->set_device_management_token("s3cr3t70k3n");
141 response
.SerializeToString(&data
);
143 static const char kGoodHeaders
[] =
145 "Content-type: application/protobuf\0"
147 std::string
headers(kGoodHeaders
, arraysize(kGoodHeaders
));
148 return new net::URLRequestTestJob(
149 request
, network_delegate
, headers
, data
, true);
152 void RegisterHttpInterceptor(
153 const std::string
& hostname
,
154 scoped_ptr
<net::URLRequestInterceptor
> interceptor
) {
155 net::URLRequestFilter::GetInstance()->AddHostnameInterceptor(
156 "http", hostname
, interceptor
.Pass());
161 class TestRequestInterceptor::Delegate
: public net::URLRequestInterceptor
{
163 Delegate(const std::string
& hostname
,
164 scoped_refptr
<base::SequencedTaskRunner
> io_task_runner
);
165 ~Delegate() override
;
167 // net::URLRequestInterceptor implementation:
168 net::URLRequestJob
* MaybeInterceptRequest(
169 net::URLRequest
* request
,
170 net::NetworkDelegate
* network_delegate
) const override
;
172 void GetPendingSize(size_t* pending_size
) const;
173 void AddRequestServicedCallback(const base::Closure
& callback
);
174 void PushJobCallback(const JobCallback
& callback
);
177 static void InvokeRequestServicedCallbacks(
178 scoped_ptr
<std::vector
<base::Closure
>> callbacks
);
180 const std::string hostname_
;
181 scoped_refptr
<base::SequencedTaskRunner
> io_task_runner_
;
183 // The queue of pending callbacks. 'mutable' because MaybeInterceptRequest() is a
184 // const method; it can't reenter though, because it runs exclusively on
186 mutable std::queue
<JobCallback
> pending_job_callbacks_
;
188 // Queue of pending request serviced callbacks. Mutable for the same reason
189 // as |pending_job_callbacks_|.
190 mutable std::vector
<base::Closure
> request_serviced_callbacks_
;
193 TestRequestInterceptor::Delegate::Delegate(
194 const std::string
& hostname
,
195 scoped_refptr
<base::SequencedTaskRunner
> io_task_runner
)
196 : hostname_(hostname
), io_task_runner_(io_task_runner
) {}
198 TestRequestInterceptor::Delegate::~Delegate() {}
200 net::URLRequestJob
* TestRequestInterceptor::Delegate::MaybeInterceptRequest(
201 net::URLRequest
* request
,
202 net::NetworkDelegate
* network_delegate
) const {
203 CHECK(io_task_runner_
->RunsTasksOnCurrentThread());
205 if (request
->url().host() != hostname_
) {
206 // Reject requests to other servers.
207 return ErrorJobCallback(
208 net::ERR_CONNECTION_REFUSED
, request
, network_delegate
);
211 if (pending_job_callbacks_
.empty()) {
212 // Reject dmserver requests by default.
213 return BadRequestJobCallback(request
, network_delegate
);
216 // Invoke any callbacks that are waiting for the next request to be serviced
217 // after this job is serviced.
218 if (!request_serviced_callbacks_
.empty()) {
219 scoped_ptr
<std::vector
<base::Closure
>> callbacks(
220 new std::vector
<base::Closure
>);
221 callbacks
->swap(request_serviced_callbacks_
);
222 io_task_runner_
->PostTask(
223 FROM_HERE
, base::Bind(&Delegate::InvokeRequestServicedCallbacks
,
224 base::Passed(&callbacks
)));
227 JobCallback callback
= pending_job_callbacks_
.front();
228 pending_job_callbacks_
.pop();
229 return callback
.Run(request
, network_delegate
);
232 void TestRequestInterceptor::Delegate::GetPendingSize(
233 size_t* pending_size
) const {
234 CHECK(io_task_runner_
->RunsTasksOnCurrentThread());
235 *pending_size
= pending_job_callbacks_
.size();
238 void TestRequestInterceptor::Delegate::AddRequestServicedCallback(
239 const base::Closure
& callback
) {
240 CHECK(io_task_runner_
->RunsTasksOnCurrentThread());
241 request_serviced_callbacks_
.push_back(callback
);
244 void TestRequestInterceptor::Delegate::PushJobCallback(
245 const JobCallback
& callback
) {
246 CHECK(io_task_runner_
->RunsTasksOnCurrentThread());
247 pending_job_callbacks_
.push(callback
);
251 void TestRequestInterceptor::Delegate::InvokeRequestServicedCallbacks(
252 scoped_ptr
<std::vector
<base::Closure
>> callbacks
) {
253 for (const auto& p
: *callbacks
)
257 TestRequestInterceptor::TestRequestInterceptor(const std::string
& hostname
,
258 scoped_refptr
<base::SequencedTaskRunner
> io_task_runner
)
259 : hostname_(hostname
),
260 io_task_runner_(io_task_runner
) {
261 delegate_
= new Delegate(hostname_
, io_task_runner_
);
262 scoped_ptr
<net::URLRequestInterceptor
> interceptor(delegate_
);
264 base::Bind(&RegisterHttpInterceptor
, hostname_
,
265 base::Passed(&interceptor
)));
268 TestRequestInterceptor::~TestRequestInterceptor() {
269 // RemoveHostnameHandler() destroys the |delegate_|, which is owned by
270 // the URLRequestFilter.
273 base::Bind(&net::URLRequestFilter::RemoveHostnameHandler
,
274 base::Unretained(net::URLRequestFilter::GetInstance()),
278 size_t TestRequestInterceptor::GetPendingSize() {
279 size_t pending_size
= std::numeric_limits
<size_t>::max();
280 PostToIOAndWait(base::Bind(&Delegate::GetPendingSize
,
281 base::Unretained(delegate_
),
286 void TestRequestInterceptor::AddRequestServicedCallback(
287 const base::Closure
& callback
) {
288 base::Closure post_callback
=
289 base::Bind(base::IgnoreResult(&base::TaskRunner::PostTask
),
290 base::ThreadTaskRunnerHandle::Get(),
293 PostToIOAndWait(base::Bind(&Delegate::AddRequestServicedCallback
,
294 base::Unretained(delegate_
),
298 void TestRequestInterceptor::PushJobCallback(const JobCallback
& callback
) {
299 PostToIOAndWait(base::Bind(&Delegate::PushJobCallback
,
300 base::Unretained(delegate_
),
305 TestRequestInterceptor::JobCallback
TestRequestInterceptor::ErrorJob(
307 return base::Bind(&ErrorJobCallback
, error
);
311 TestRequestInterceptor::JobCallback
TestRequestInterceptor::BadRequestJob() {
312 return base::Bind(&BadRequestJobCallback
);
316 TestRequestInterceptor::JobCallback
TestRequestInterceptor::RegisterJob(
317 em::DeviceRegisterRequest::Type expected_type
,
318 bool expect_reregister
) {
319 return base::Bind(&RegisterJobCallback
, expected_type
, expect_reregister
);
323 TestRequestInterceptor::JobCallback
TestRequestInterceptor::FileJob(
324 const base::FilePath
& file_path
) {
325 return base::Bind(&FileJobCallback
, file_path
);
328 void TestRequestInterceptor::PostToIOAndWait(const base::Closure
& task
) {
329 io_task_runner_
->PostTask(FROM_HERE
, task
);
330 base::RunLoop run_loop
;
331 io_task_runner_
->PostTask(
334 base::IgnoreResult(&base::TaskRunner::PostTask
),
335 base::ThreadTaskRunnerHandle::Get(),
337 run_loop
.QuitClosure()));
341 } // namespace policy