1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/policy/cloud/test_request_interceptor.h"
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/run_loop.h"
14 #include "base/sequenced_task_runner.h"
15 #include "content/test/net/url_request_mock_http_job.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/upload_bytes_element_reader.h"
18 #include "net/base/upload_data_stream.h"
19 #include "net/base/upload_element_reader.h"
20 #include "net/url_request/url_request_error_job.h"
21 #include "net/url_request/url_request_filter.h"
22 #include "net/url_request/url_request_job_factory.h"
23 #include "net/url_request/url_request_test_job.h"
// Short alias for the enterprise_management namespace; the em:: message types
// (e.g. em::DeviceManagementRequest) are referenced through it below.
26 namespace em
= enterprise_management
;
32 // Helper callback for jobs that should fail with a network |error|.
33 net::URLRequestJob
* ErrorJobCallback(int error
,
34 net::URLRequest
* request
,
35 net::NetworkDelegate
* network_delegate
) {
36 return new net::URLRequestErrorJob(request
, network_delegate
, error
);
39 // Helper callback for jobs that should fail with a 400 HTTP error.
// Returns a net::URLRequestTestJob built from the NUL-separated header lines in
// |kBadHeaders| and an empty response body.
40 net::URLRequestJob
* BadRequestJobCallback(
41 net::URLRequest
* request
,
42 net::NetworkDelegate
* network_delegate
) {
43 static const char kBadHeaders
[] =
44 "HTTP/1.1 400 Bad request\0"
45 "Content-type: application/protobuf\0"
// The (pointer, length) std::string constructor is used so that the embedded
// NUL separators in |kBadHeaders| are copied into |headers| as well.
47 std::string
headers(kBadHeaders
, arraysize(kBadHeaders
));
// NOTE(review): the trailing |true| is presumably URLRequestTestJob's
// auto-advance flag -- confirm against its constructor.
48 return new net::URLRequestTestJob(
49 request
, network_delegate
, headers
, std::string(), true);
// Helper callback that answers |request| with a newly created
// content::URLRequestMockHTTPJob.
// NOTE(review): the constructor arguments are not visible in this view;
// presumably |file_path| names the file that supplies the response -- confirm.
52 net::URLRequestJob
* FileJobCallback(const base::FilePath
& file_path
,
53 net::URLRequest
* request
,
54 net::NetworkDelegate
* network_delegate
) {
55 return new content::URLRequestMockHTTPJob(
61 // Parses the upload data in |request| into |request_msg|, and validates the
62 // request. The query string in the URL must contain the |expected_type| for
63 // the "request" parameter. Returns true if all checks succeeded, and the
64 // request data has been parsed into |request_msg|.
65 bool ValidRequest(net::URLRequest
* request
,
66 const std::string
& expected_type
,
67 em::DeviceManagementRequest
* request_msg
) {
68 if (request
->method() != "POST")
70 std::string spec
= request
->url().spec();
71 if (spec
.find("request=" + expected_type
) == std::string::npos
)
74 // This assumes that the payload data was set from a single string. In that
75 // case the UploadDataStream has a single UploadBytesElementReader with the
77 const net::UploadDataStream
* stream
= request
->get_upload();
80 const ScopedVector
<net::UploadElementReader
>& readers
=
81 stream
->element_readers();
82 if (readers
.size() != 1u)
84 const net::UploadBytesElementReader
* reader
= readers
[0]->AsBytesReader();
87 std::string
data(reader
->bytes(), reader
->length());
88 if (!request_msg
->ParseFromString(data
))
94 // Helper callback for register jobs that should succeed. Validates the request
95 // parameters and returns an appropriate response job. If |expect_reregister|
96 // is true then the reregister flag must be set in the DeviceRegisterRequest;
// if it is false, a request that has the flag set is rejected. The request's
// type must equal |expected_type|. Every failed check is answered with
// BadRequestJobCallback().
98 net::URLRequestJob
* RegisterJobCallback(
99 em::DeviceRegisterRequest::Type expected_type
,
100 bool expect_reregister
,
101 net::URLRequest
* request
,
102 net::NetworkDelegate
* network_delegate
) {
103 em::DeviceManagementRequest request_msg
;
// The request must pass ValidRequest() for the "register" request type.
104 if (!ValidRequest(request
, "register", &request_msg
))
105 return BadRequestJobCallback(request
, network_delegate
);
// The message must carry a register request and none of the other request
// kinds checked here.
107 if (!request_msg
.has_register_request() ||
108 request_msg
.has_unregister_request() ||
109 request_msg
.has_policy_request() ||
110 request_msg
.has_device_status_report_request() ||
111 request_msg
.has_session_status_report_request() ||
112 request_msg
.has_auto_enrollment_request()) {
113 return BadRequestJobCallback(request
, network_delegate
);
116 const em::DeviceRegisterRequest
& register_request
=
117 request_msg
.register_request();
// The reregister flag has to agree with |expect_reregister|: expected but
// absent/false is rejected, and so is present-and-true when not expected.
118 if (expect_reregister
&&
119 (!register_request
.has_reregister() || !register_request
.reregister())) {
120 return BadRequestJobCallback(request
, network_delegate
);
121 } else if (!expect_reregister
&&
122 register_request
.has_reregister() &&
123 register_request
.reregister()) {
124 return BadRequestJobCallback(request
, network_delegate
);
// The registration type must be present and equal to |expected_type|.
127 if (!register_request
.has_type() || register_request
.type() != expected_type
)
128 return BadRequestJobCallback(request
, network_delegate
);
// All checks passed: build a response whose register_response carries a fixed
// test token, and serialize it as the response body.
130 em::DeviceManagementResponse response
;
131 em::DeviceRegisterResponse
* register_response
=
132 response
.mutable_register_response();
133 register_response
->set_device_management_token("s3cr3t70k3n");
// NOTE(review): the declaration of |data| is not visible in this view.
135 response
.SerializeToString(&data
);
137 static const char kGoodHeaders
[] =
139 "Content-type: application/protobuf\0"
// The (pointer, length) std::string constructor keeps the embedded NUL
// separators of |kGoodHeaders| in |headers|.
141 std::string
headers(kGoodHeaders
, arraysize(kGoodHeaders
));
142 return new net::URLRequestTestJob(
143 request
, network_delegate
, headers
, data
, true);
// ProtocolHandler that answers intercepted requests from a FIFO queue of
// JobCallbacks. It stores the hostname it serves and the task runner its
// methods CHECK they are running on.
148 class TestRequestInterceptor::Delegate
149 : public net::URLRequestJobFactory::ProtocolHandler
{
151 Delegate(const std::string
& hostname
,
152 scoped_refptr
<base::SequencedTaskRunner
> io_task_runner
);
155 // ProtocolHandler implementation:
156 virtual net::URLRequestJob
* MaybeCreateJob(
157 net::URLRequest
* request
,
158 net::NetworkDelegate
* network_delegate
) const OVERRIDE
;
// Writes the number of queued callbacks to |pending_size|.
160 void GetPendingSize(size_t* pending_size
) const;
// Appends |callback| to the queue of pending callbacks.
161 void PushJobCallback(const JobCallback
& callback
);
// Hostname whose requests are served; requests for other hosts are refused.
164 const std::string hostname_
;
// Task runner on which the methods of this class must be invoked.
165 scoped_refptr
<base::SequencedTaskRunner
> io_task_runner_
;
167 // The queue of pending callbacks. 'mutable' because MaybeCreateJob() is a
168 // const method; it can't reenter though, because it runs exclusively on
// |io_task_runner_| (MaybeCreateJob() CHECKs this).
170 mutable std::queue
<JobCallback
> pending_job_callbacks_
;
173 TestRequestInterceptor::Delegate::Delegate(
174 const std::string
& hostname
,
175 scoped_refptr
<base::SequencedTaskRunner
> io_task_runner
)
176 : hostname_(hostname
), io_task_runner_(io_task_runner
) {}
178 TestRequestInterceptor::Delegate::~Delegate() {}
180 net::URLRequestJob
* TestRequestInterceptor::Delegate::MaybeCreateJob(
181 net::URLRequest
* request
,
182 net::NetworkDelegate
* network_delegate
) const {
183 CHECK(io_task_runner_
->RunsTasksOnCurrentThread());
185 if (request
->url().host() != hostname_
) {
186 // Reject requests to other servers.
187 return ErrorJobCallback(
188 net::ERR_CONNECTION_REFUSED
, request
, network_delegate
);
191 if (pending_job_callbacks_
.empty()) {
192 // Reject dmserver requests by default.
193 return BadRequestJobCallback(request
, network_delegate
);
196 JobCallback callback
= pending_job_callbacks_
.front();
197 pending_job_callbacks_
.pop();
198 return callback
.Run(request
, network_delegate
);
201 void TestRequestInterceptor::Delegate::GetPendingSize(
202 size_t* pending_size
) const {
203 CHECK(io_task_runner_
->RunsTasksOnCurrentThread());
204 *pending_size
= pending_job_callbacks_
.size();
207 void TestRequestInterceptor::Delegate::PushJobCallback(
208 const JobCallback
& callback
) {
209 CHECK(io_task_runner_
->RunsTasksOnCurrentThread());
210 pending_job_callbacks_
.push(callback
);
// Creates the Delegate for |hostname| and binds a call that registers it with
// the URLRequestFilter singleton as the "http" protocol handler for
// |hostname_|.
213 TestRequestInterceptor::TestRequestInterceptor(const std::string
& hostname
,
214 scoped_refptr
<base::SequencedTaskRunner
> io_task_runner
)
215 : hostname_(hostname
),
216 io_task_runner_(io_task_runner
) {
// |delegate_| keeps a raw pointer; |handler| wraps the same object in a
// scoped_ptr that is handed over to the bound call via base::Passed().
217 delegate_
= new Delegate(hostname_
, io_task_runner_
);
218 scoped_ptr
<net::URLRequestJobFactory::ProtocolHandler
> handler(delegate_
);
// NOTE(review): the call that receives this bound closure is not visible in
// this view; presumably PostToIOAndWait() -- confirm.
220 base::Bind(&net::URLRequestFilter::AddHostnameProtocolHandler
,
221 base::Unretained(net::URLRequestFilter::GetInstance()),
222 "http", hostname_
, base::Passed(&handler
)));
// Binds a call to URLRequestFilter::RemoveHostnameHandler on the filter
// singleton to undo the registration.
225 TestRequestInterceptor::~TestRequestInterceptor() {
226 // RemoveHostnameHandler() destroys the |delegate_|, which is owned by
227 // the URLRequestFilter.
// NOTE(review): the enclosing call and the remaining bound arguments are not
// visible in this view.
230 base::Bind(&net::URLRequestFilter::RemoveHostnameHandler
,
231 base::Unretained(net::URLRequestFilter::GetInstance()),
// Queries the delegate's pending-callback count by running
// Delegate::GetPendingSize through PostToIOAndWait().
235 size_t TestRequestInterceptor::GetPendingSize() {
// Starts out as the largest size_t value.
236 size_t pending_size
= std::numeric_limits
<size_t>::max();
// NOTE(review): the remaining bound argument and the return statement are not
// visible in this view; presumably |&pending_size| is bound and then returned.
237 PostToIOAndWait(base::Bind(&Delegate::GetPendingSize
,
238 base::Unretained(delegate_
),
// Runs Delegate::PushJobCallback on |delegate_| through PostToIOAndWait().
// NOTE(review): the remaining bound argument is not visible in this view;
// presumably |callback| -- confirm.
243 void TestRequestInterceptor::PushJobCallback(const JobCallback
& callback
) {
244 PostToIOAndWait(base::Bind(&Delegate::PushJobCallback
,
245 base::Unretained(delegate_
),
// Returns a JobCallback that binds |error| as the first argument of
// ErrorJobCallback().
// NOTE(review): the parameter list declaring |error| is not visible in this
// view.
250 TestRequestInterceptor::JobCallback
TestRequestInterceptor::ErrorJob(
252 return base::Bind(&ErrorJobCallback
, error
);
256 TestRequestInterceptor::JobCallback
TestRequestInterceptor::BadRequestJob() {
257 return base::Bind(&BadRequestJobCallback
);
261 TestRequestInterceptor::JobCallback
TestRequestInterceptor::RegisterJob(
262 em::DeviceRegisterRequest::Type expected_type
,
263 bool expect_reregister
) {
264 return base::Bind(&RegisterJobCallback
, expected_type
, expect_reregister
);
268 TestRequestInterceptor::JobCallback
TestRequestInterceptor::FileJob(
269 const base::FilePath
& file_path
) {
270 return base::Bind(&FileJobCallback
, file_path
);
// Posts |task| to |io_task_runner_|, then posts a second task to the same
// runner that uses MessageLoopProxy::PostTask on the calling thread's
// MessageLoopProxy to deliver |run_loop|'s quit closure.
273 void TestRequestInterceptor::PostToIOAndWait(const base::Closure
& task
) {
274 io_task_runner_
->PostTask(FROM_HERE
, task
);
275 base::RunLoop run_loop
;
// IgnoreResult() drops the value returned by MessageLoopProxy::PostTask so the
// bound call can be used as a closure.
// NOTE(review): some arguments of this call and the statement that runs
// |run_loop| are not visible in this view; presumably run_loop.Run() follows.
276 io_task_runner_
->PostTask(
279 base::IgnoreResult(&base::MessageLoopProxy::PostTask
),
280 base::MessageLoopProxy::current(),
282 run_loop
.QuitClosure()));
286 } // namespace policy