// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
5 #include "net/test/embedded_test_server/embedded_test_server.h"
8 #include "base/files/file_path.h"
9 #include "base/file_util.h"
10 #include "base/path_service.h"
11 #include "base/run_loop.h"
12 #include "base/stl_util.h"
13 #include "base/strings/string_util.h"
14 #include "base/strings/stringprintf.h"
15 #include "base/threading/thread_restrictions.h"
16 #include "net/base/ip_endpoint.h"
17 #include "net/base/net_errors.h"
18 #include "net/test/embedded_test_server/http_connection.h"
19 #include "net/test/embedded_test_server/http_request.h"
20 #include "net/test/embedded_test_server/http_response.h"
21 #include "net/tools/fetch/http_listen_socket.h"
24 namespace test_server
{
28 class CustomHttpResponse
: public HttpResponse
{
30 CustomHttpResponse(const std::string
& headers
, const std::string
& contents
)
31 : headers_(headers
), contents_(contents
) {
34 virtual std::string
ToResponseString() const OVERRIDE
{
35 return headers_
+ "\r\n" + contents_
;
40 std::string contents_
;
42 DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse
);
45 // Handles |request| by serving a file from under |server_root|.
46 scoped_ptr
<HttpResponse
> HandleFileRequest(
47 const base::FilePath
& server_root
,
48 const HttpRequest
& request
) {
49 // This is a test-only server. Ignore I/O thread restrictions.
50 base::ThreadRestrictions::ScopedAllowIO allow_io
;
52 // Trim the first byte ('/').
53 std::string
request_path(request
.relative_url
.substr(1));
55 // Remove the query string if present.
56 size_t query_pos
= request_path
.find('?');
57 if (query_pos
!= std::string::npos
)
58 request_path
= request_path
.substr(0, query_pos
);
60 base::FilePath
file_path(server_root
.AppendASCII(request_path
));
61 std::string file_contents
;
62 if (!file_util::ReadFileToString(file_path
, &file_contents
))
63 return scoped_ptr
<HttpResponse
>();
65 base::FilePath
headers_path(
66 file_path
.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
68 if (base::PathExists(headers_path
)) {
69 std::string headers_contents
;
70 if (!file_util::ReadFileToString(headers_path
, &headers_contents
))
71 return scoped_ptr
<HttpResponse
>();
73 scoped_ptr
<CustomHttpResponse
> http_response(
74 new CustomHttpResponse(headers_contents
, file_contents
));
75 return http_response
.PassAs
<HttpResponse
>();
78 scoped_ptr
<BasicHttpResponse
> http_response(new BasicHttpResponse
);
79 http_response
->set_code(HTTP_OK
);
80 http_response
->set_content(file_contents
);
81 return http_response
.PassAs
<HttpResponse
>();
86 HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor
,
87 StreamListenSocket::Delegate
* delegate
)
88 : TCPListenSocket(socket_descriptor
, delegate
) {
89 DCHECK(thread_checker_
.CalledOnValidThread());
92 void HttpListenSocket::Listen() {
93 DCHECK(thread_checker_
.CalledOnValidThread());
94 TCPListenSocket::Listen();
97 HttpListenSocket::~HttpListenSocket() {
98 DCHECK(thread_checker_
.CalledOnValidThread());
101 EmbeddedTestServer::EmbeddedTestServer(
102 const scoped_refptr
<base::SingleThreadTaskRunner
>& io_thread
)
103 : io_thread_(io_thread
),
105 weak_factory_(this) {
106 DCHECK(io_thread_
.get());
107 DCHECK(thread_checker_
.CalledOnValidThread());
110 EmbeddedTestServer::~EmbeddedTestServer() {
111 DCHECK(thread_checker_
.CalledOnValidThread());
113 if (Started() && !ShutdownAndWaitUntilComplete()) {
114 LOG(ERROR
) << "EmbeddedTestServer failed to shut down.";
118 bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
119 DCHECK(thread_checker_
.CalledOnValidThread());
121 base::RunLoop run_loop
;
122 if (!io_thread_
->PostTaskAndReply(
124 base::Bind(&EmbeddedTestServer::InitializeOnIOThread
,
125 base::Unretained(this)),
126 run_loop
.QuitClosure())) {
131 return Started() && base_url_
.is_valid();
134 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
135 DCHECK(thread_checker_
.CalledOnValidThread());
137 base::RunLoop run_loop
;
138 if (!io_thread_
->PostTaskAndReply(
140 base::Bind(&EmbeddedTestServer::ShutdownOnIOThread
,
141 base::Unretained(this)),
142 run_loop
.QuitClosure())) {
150 void EmbeddedTestServer::InitializeOnIOThread() {
151 DCHECK(io_thread_
->BelongsToCurrentThread());
154 SocketDescriptor socket_descriptor
=
155 TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_
);
156 if (socket_descriptor
== TCPListenSocket::kInvalidSocket
)
159 listen_socket_
= new HttpListenSocket(socket_descriptor
, this);
160 listen_socket_
->Listen();
163 int result
= listen_socket_
->GetLocalAddress(&address
);
165 base_url_
= GURL(std::string("http://") + address
.ToString());
167 LOG(ERROR
) << "GetLocalAddress failed: " << ErrorToString(result
);
171 void EmbeddedTestServer::ShutdownOnIOThread() {
172 DCHECK(io_thread_
->BelongsToCurrentThread());
174 listen_socket_
= NULL
; // Release the listen socket.
175 STLDeleteContainerPairSecondPointers(connections_
.begin(),
177 connections_
.clear();
180 void EmbeddedTestServer::HandleRequest(HttpConnection
* connection
,
181 scoped_ptr
<HttpRequest
> request
) {
182 DCHECK(io_thread_
->BelongsToCurrentThread());
184 bool request_handled
= false;
186 for (size_t i
= 0; i
< request_handlers_
.size(); ++i
) {
187 scoped_ptr
<HttpResponse
> response
=
188 request_handlers_
[i
].Run(*request
.get());
189 if (response
.get()) {
190 connection
->SendResponse(response
.Pass());
191 request_handled
= true;
196 if (!request_handled
) {
197 LOG(WARNING
) << "Request not handled. Returning 404: "
198 << request
->relative_url
;
199 scoped_ptr
<BasicHttpResponse
> not_found_response(new BasicHttpResponse
);
200 not_found_response
->set_code(HTTP_NOT_FOUND
);
201 connection
->SendResponse(
202 not_found_response
.PassAs
<HttpResponse
>());
205 // Drop the connection, since we do not support multiple requests per
207 connections_
.erase(connection
->socket_
.get());
211 GURL
EmbeddedTestServer::GetURL(const std::string
& relative_url
) const {
212 DCHECK(StartsWithASCII(relative_url
, "/", true /* case_sensitive */))
214 return base_url_
.Resolve(relative_url
);
217 void EmbeddedTestServer::ServeFilesFromDirectory(
218 const base::FilePath
& directory
) {
219 RegisterRequestHandler(base::Bind(&HandleFileRequest
, directory
));
222 void EmbeddedTestServer::RegisterRequestHandler(
223 const HandleRequestCallback
& callback
) {
224 request_handlers_
.push_back(callback
);
227 void EmbeddedTestServer::DidAccept(StreamListenSocket
* server
,
228 StreamListenSocket
* connection
) {
229 DCHECK(io_thread_
->BelongsToCurrentThread());
231 HttpConnection
* http_connection
= new HttpConnection(
233 base::Bind(&EmbeddedTestServer::HandleRequest
,
234 weak_factory_
.GetWeakPtr()));
235 connections_
[connection
] = http_connection
;
238 void EmbeddedTestServer::DidRead(StreamListenSocket
* connection
,
241 DCHECK(io_thread_
->BelongsToCurrentThread());
243 HttpConnection
* http_connection
= FindConnection(connection
);
244 if (http_connection
== NULL
) {
245 LOG(WARNING
) << "Unknown connection.";
248 http_connection
->ReceiveData(std::string(data
, length
));
251 void EmbeddedTestServer::DidClose(StreamListenSocket
* connection
) {
252 DCHECK(io_thread_
->BelongsToCurrentThread());
254 HttpConnection
* http_connection
= FindConnection(connection
);
255 if (http_connection
== NULL
) {
256 LOG(WARNING
) << "Unknown connection.";
259 delete http_connection
;
260 connections_
.erase(connection
);
263 HttpConnection
* EmbeddedTestServer::FindConnection(
264 StreamListenSocket
* socket
) {
265 DCHECK(io_thread_
->BelongsToCurrentThread());
267 std::map
<StreamListenSocket
*, HttpConnection
*>::iterator it
=
268 connections_
.find(socket
);
269 if (it
== connections_
.end()) {
275 } // namespace test_server