Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / net / test / embedded_test_server / embedded_test_server.cc
blob5b0079c84726624bef27340c997ae46bca2c2b99
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/test/embedded_test_server/embedded_test_server.h"
7 #include "base/bind.h"
8 #include "base/files/file_path.h"
9 #include "base/files/file_util.h"
10 #include "base/location.h"
11 #include "base/logging.h"
12 #include "base/message_loop/message_loop.h"
13 #include "base/process/process_metrics.h"
14 #include "base/run_loop.h"
15 #include "base/stl_util.h"
16 #include "base/strings/string_util.h"
17 #include "base/strings/stringprintf.h"
18 #include "base/thread_task_runner_handle.h"
19 #include "base/threading/thread_restrictions.h"
20 #include "net/base/ip_endpoint.h"
21 #include "net/base/net_errors.h"
22 #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h"
23 #include "net/test/embedded_test_server/http_connection.h"
24 #include "net/test/embedded_test_server/http_request.h"
25 #include "net/test/embedded_test_server/http_response.h"
27 namespace net {
28 namespace test_server {
30 namespace {
32 class CustomHttpResponse : public HttpResponse {
33 public:
34 CustomHttpResponse(const std::string& headers, const std::string& contents)
35 : headers_(headers), contents_(contents) {
38 std::string ToResponseString() const override {
39 return headers_ + "\r\n" + contents_;
42 private:
43 std::string headers_;
44 std::string contents_;
46 DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
49 // Handles |request| by serving a file from under |server_root|.
50 scoped_ptr<HttpResponse> HandleFileRequest(
51 const base::FilePath& server_root,
52 const HttpRequest& request) {
53 // This is a test-only server. Ignore I/O thread restrictions.
54 base::ThreadRestrictions::ScopedAllowIO allow_io;
56 std::string relative_url(request.relative_url);
57 // A proxy request will have an absolute path. Simulate the proxy by stripping
58 // the scheme, host, and port.
59 GURL relative_gurl(relative_url);
60 if (relative_gurl.is_valid())
61 relative_url = relative_gurl.PathForRequest();
63 // Trim the first byte ('/').
64 std::string request_path = relative_url.substr(1);
66 // Remove the query string if present.
67 size_t query_pos = request_path.find('?');
68 if (query_pos != std::string::npos)
69 request_path = request_path.substr(0, query_pos);
71 base::FilePath file_path(server_root.AppendASCII(request_path));
72 std::string file_contents;
73 if (!base::ReadFileToString(file_path, &file_contents))
74 return scoped_ptr<HttpResponse>();
76 base::FilePath headers_path(
77 file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
79 if (base::PathExists(headers_path)) {
80 std::string headers_contents;
81 if (!base::ReadFileToString(headers_path, &headers_contents))
82 return scoped_ptr<HttpResponse>();
84 scoped_ptr<CustomHttpResponse> http_response(
85 new CustomHttpResponse(headers_contents, file_contents));
86 return http_response.Pass();
89 scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
90 http_response->set_code(HTTP_OK);
91 http_response->set_content(file_contents);
92 return http_response.Pass();
95 } // namespace
97 HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor,
98 StreamListenSocket::Delegate* delegate)
99 : TCPListenSocket(socket_descriptor, delegate) {
100 DCHECK(thread_checker_.CalledOnValidThread());
103 void HttpListenSocket::Listen() {
104 DCHECK(thread_checker_.CalledOnValidThread());
105 TCPListenSocket::Listen();
108 void HttpListenSocket::ListenOnIOThread() {
109 DCHECK(thread_checker_.CalledOnValidThread());
110 #if !defined(OS_POSIX)
111 // This method may be called after the IO thread is changed, thus we need to
112 // call |WatchSocket| again to make sure it listens on the current IO thread.
113 // Only needed for non POSIX platforms, since on POSIX platforms
114 // StreamListenSocket::Listen already calls WatchSocket inside the function.
115 WatchSocket(WAITING_ACCEPT);
116 #endif
117 Listen();
120 HttpListenSocket::~HttpListenSocket() {
121 DCHECK(thread_checker_.CalledOnValidThread());
124 void HttpListenSocket::DetachFromThread() {
125 thread_checker_.DetachFromThread();
128 EmbeddedTestServer::EmbeddedTestServer()
129 : connection_listener_(nullptr), port_(0), weak_factory_(this) {
130 DCHECK(thread_checker_.CalledOnValidThread());
133 EmbeddedTestServer::~EmbeddedTestServer() {
134 DCHECK(thread_checker_.CalledOnValidThread());
136 if (Started() && !ShutdownAndWaitUntilComplete()) {
137 LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
141 void EmbeddedTestServer::SetConnectionListener(
142 EmbeddedTestServerConnectionListener* listener) {
143 DCHECK(!Started());
144 connection_listener_ = listener;
147 bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
148 StartThread();
149 DCHECK(thread_checker_.CalledOnValidThread());
150 if (!PostTaskToIOThreadAndWait(base::Bind(
151 &EmbeddedTestServer::InitializeOnIOThread, base::Unretained(this)))) {
152 return false;
154 return Started() && base_url_.is_valid();
157 void EmbeddedTestServer::StopThread() {
158 DCHECK(io_thread_ && io_thread_->IsRunning());
160 #if defined(OS_LINUX)
161 const int thread_count =
162 base::GetNumberOfThreads(base::GetCurrentProcessHandle());
163 #endif
165 io_thread_->Stop();
166 io_thread_.reset();
167 thread_checker_.DetachFromThread();
168 listen_socket_->DetachFromThread();
170 #if defined(OS_LINUX)
171 // Busy loop to wait for thread count to decrease. This is needed because
172 // pthread_join does not guarantee that kernel stat is updated when it
173 // returns. Thus, GetNumberOfThreads does not immediately reflect the stopped
174 // thread and hits the thread number DCHECK in render_sandbox_host_linux.cc
175 // in browser_tests.
176 while (thread_count ==
177 base::GetNumberOfThreads(base::GetCurrentProcessHandle())) {
178 base::PlatformThread::YieldCurrentThread();
180 #endif
183 void EmbeddedTestServer::RestartThreadAndListen() {
184 StartThread();
185 CHECK(PostTaskToIOThreadAndWait(base::Bind(
186 &EmbeddedTestServer::ListenOnIOThread, base::Unretained(this))));
189 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
190 DCHECK(thread_checker_.CalledOnValidThread());
192 return PostTaskToIOThreadAndWait(base::Bind(
193 &EmbeddedTestServer::ShutdownOnIOThread, base::Unretained(this)));
196 void EmbeddedTestServer::StartThread() {
197 DCHECK(!io_thread_.get());
198 base::Thread::Options thread_options;
199 thread_options.message_loop_type = base::MessageLoop::TYPE_IO;
200 io_thread_.reset(new base::Thread("EmbeddedTestServer io thread"));
201 CHECK(io_thread_->StartWithOptions(thread_options));
202 CHECK(io_thread_->WaitUntilThreadStarted());
205 void EmbeddedTestServer::InitializeOnIOThread() {
206 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
207 DCHECK(!Started());
209 SocketDescriptor socket_descriptor =
210 TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_);
211 if (socket_descriptor == kInvalidSocket)
212 return;
214 listen_socket_.reset(new HttpListenSocket(socket_descriptor, this));
215 listen_socket_->Listen();
217 IPEndPoint address;
218 int result = listen_socket_->GetLocalAddress(&address);
219 if (result == OK) {
220 base_url_ = GURL(std::string("http://") + address.ToString());
221 } else {
222 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
226 void EmbeddedTestServer::ListenOnIOThread() {
227 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
228 DCHECK(Started());
229 listen_socket_->ListenOnIOThread();
232 void EmbeddedTestServer::ShutdownOnIOThread() {
233 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
235 listen_socket_.reset();
236 STLDeleteContainerPairSecondPointers(connections_.begin(),
237 connections_.end());
238 connections_.clear();
241 void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
242 scoped_ptr<HttpRequest> request) {
243 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
245 bool request_handled = false;
247 for (size_t i = 0; i < request_handlers_.size(); ++i) {
248 scoped_ptr<HttpResponse> response =
249 request_handlers_[i].Run(*request.get());
250 if (response.get()) {
251 connection->SendResponse(response.Pass());
252 request_handled = true;
253 break;
257 if (!request_handled) {
258 LOG(WARNING) << "Request not handled. Returning 404: "
259 << request->relative_url;
260 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
261 not_found_response->set_code(HTTP_NOT_FOUND);
262 connection->SendResponse(not_found_response.Pass());
265 // Drop the connection, since we do not support multiple requests per
266 // connection.
267 connections_.erase(connection->socket_.get());
268 delete connection;
271 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
272 DCHECK(Started()) << "You must start the server first.";
273 DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE))
274 << relative_url;
275 return base_url_.Resolve(relative_url);
278 GURL EmbeddedTestServer::GetURL(
279 const std::string& hostname,
280 const std::string& relative_url) const {
281 GURL local_url = GetURL(relative_url);
282 GURL::Replacements replace_host;
283 replace_host.SetHostStr(hostname);
284 return local_url.ReplaceComponents(replace_host);
287 bool EmbeddedTestServer::GetAddressList(net::AddressList* address_list) const {
288 if (!listen_socket_)
289 return false;
290 IPEndPoint endpoint;
291 int result = listen_socket_->GetLocalAddress(&endpoint);
292 if (result != OK)
293 return false;
295 *address_list = net::AddressList(endpoint);
296 return true;
299 void EmbeddedTestServer::ServeFilesFromDirectory(
300 const base::FilePath& directory) {
301 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
304 void EmbeddedTestServer::RegisterRequestHandler(
305 const HandleRequestCallback& callback) {
306 request_handlers_.push_back(callback);
309 void EmbeddedTestServer::DidAccept(StreamListenSocket* server,
310 scoped_ptr<StreamListenSocket> connection) {
311 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
312 if (connection_listener_)
313 connection_listener_->AcceptedSocket(*connection);
315 HttpConnection* http_connection = new HttpConnection(
316 connection.Pass(), base::Bind(&EmbeddedTestServer::HandleRequest,
317 weak_factory_.GetWeakPtr()));
318 // TODO(szym): Make HttpConnection the StreamListenSocket delegate.
319 connections_[http_connection->socket_.get()] = http_connection;
322 void EmbeddedTestServer::DidRead(StreamListenSocket* connection,
323 const char* data,
324 int length) {
325 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
326 if (connection_listener_)
327 connection_listener_->ReadFromSocket(*connection);
329 HttpConnection* http_connection = FindConnection(connection);
330 if (http_connection == NULL) {
331 LOG(WARNING) << "Unknown connection.";
332 return;
334 http_connection->ReceiveData(std::string(data, length));
337 void EmbeddedTestServer::DidClose(StreamListenSocket* connection) {
338 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
340 HttpConnection* http_connection = FindConnection(connection);
341 if (http_connection == NULL) {
342 LOG(WARNING) << "Unknown connection.";
343 return;
345 delete http_connection;
346 connections_.erase(connection);
349 HttpConnection* EmbeddedTestServer::FindConnection(
350 StreamListenSocket* socket) {
351 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
353 std::map<StreamListenSocket*, HttpConnection*>::iterator it =
354 connections_.find(socket);
355 if (it == connections_.end()) {
356 return NULL;
358 return it->second;
361 bool EmbeddedTestServer::PostTaskToIOThreadAndWait(
362 const base::Closure& closure) {
363 // Note that PostTaskAndReply below requires
364 // base::ThreadTaskRunnerHandle::Get() to return a task runner for posting
365 // the reply task. However, in order to make EmbeddedTestServer universally
366 // usable, it needs to cope with the situation where it's running on a thread
367 // on which a message loop is not (yet) available or as has been destroyed
368 // already.
370 // To handle this situation, create temporary message loop to support the
371 // PostTaskAndReply operation if the current thread as no message loop.
372 scoped_ptr<base::MessageLoop> temporary_loop;
373 if (!base::MessageLoop::current())
374 temporary_loop.reset(new base::MessageLoop());
376 base::RunLoop run_loop;
377 if (!io_thread_->task_runner()->PostTaskAndReply(FROM_HERE, closure,
378 run_loop.QuitClosure())) {
379 return false;
381 run_loop.Run();
383 return true;
386 } // namespace test_server
387 } // namespace net