1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome_frame/test/test_server.h"
11 #include "base/bind.h"
12 #include "base/logging.h"
13 #include "base/strings/string_number_conversions.h"
14 #include "base/strings/string_piece.h"
15 #include "base/strings/string_util.h"
16 #include "base/strings/stringprintf.h"
17 #include "base/strings/utf_string_conversions.h"
18 #include "chrome_frame/test/chrome_frame_test_utils.h"
19 #include "net/base/winsock_init.h"
20 #include "net/http/http_util.h"
21 #include "net/socket/tcp_listen_socket.h"
23 namespace test_server
{
// printf-style template for the header block of a response. It emits a
// "Connection: close" field followed by Content-Type (%hs) and
// Content-Length (%i) fields and the blank line that ends the headers.
// NOTE(review): SimpleWebServer::DidRead passes a status string (kStatusOk)
// ahead of the content type, which implies a leading status-line placeholder
// that is not part of the text visible here -- confirm.
24 const char kDefaultHeaderTemplate
[] =
26 "Connection: close\r\n"
27 "Content-Type: %hs\r\n"
28 "Content-Length: %i\r\n\r\n";
// Status text for a successful response; SimpleWebServer::DidRead feeds it to
// kDefaultHeaderTemplate when a matching response is served.
29 const char kStatusOk
[] = "200 OK";
// Status text "404 Not Found".
// NOTE(review): presumably used by SimpleWebServer::DidRead when no response
// matches the request -- confirm at the call site.
30 const char kStatusNotFound
[] = "404 Not Found";
// Content type that SimpleWebServer::DidRead falls back to when a response's
// GetContentType() does not supply one.
31 const char kDefaultContentType
[] = "text/html; charset=UTF-8";
// Parses the header text of a request.
// The text before the first "\r\n" is treated as the request line and split
// on spaces into method_, path_ and version_ (in that order); everything
// after that line break is stored in headers_. The stored headers are then
// iterated and, for a field whose name matches "content-length" (compared
// with LowerCaseEqualsASCII), the integer value is recorded in
// content_length_.
// DCHECKs that method_ is still empty and that |headers| contains "\r\n".
33 void Request::ParseHeaders(const std::string
& headers
) {
34 DCHECK(method_
.length() == 0);
36 size_t pos
= headers
.find("\r\n");
37 DCHECK(pos
!= std::string::npos
);
38 if (pos
!= std::string::npos
) {
39 headers_
= headers
.substr(pos
+ 2);
41 base::StringTokenizer
tokenizer(
42 headers
.begin(), headers
.begin() + pos
, " ");
43 std::string
* parse
[] = { &method_
, &path_
, &version_
};
45 while (tokenizer
.GetNext() && field
< arraysize(parse
)) {
46 parse
[field
++]->assign(tokenizer
.token_begin(),
47 tokenizer
.token_end());
51 // Check for content-length in case we're being sent some data.
52 net::HttpUtil::HeadersIterator
it(headers_
.begin(), headers_
.end(),
54 while (it
.GetNext()) {
55 if (LowerCaseEqualsASCII(it
.name(), "content-length")) {
56 int int_content_length
;
57 base::StringToInt(base::StringPiece(it
.values_begin(),
60 content_length_
= int_content_length
;
// Handles a newly received chunk of request bytes.
// While the request line has not been parsed yet (method_ is empty), content_
// is searched for the "\r\n\r\n" that ends the header block. Once found, the
// header text (including the final header's own "\r\n") is passed to
// ParseHeaders() and the headers plus the blank-line terminator are erased
// from the front of content_.
66 void Request::OnDataReceived(const std::string
& data
) {
69 if (method_
.length() == 0) {
70 size_t index
= content_
.find("\r\n\r\n");
71 if (index
!= std::string::npos
) {
72 // Parse the headers before returning and chop them off the
73 // data buffer we've already received.
74 std::string
headers(content_
.substr(0, index
+ 2));
75 ParseHeaders(headers
);
76 content_
.erase(0, index
+ 4);
// Out-of-line destructor for ResponseForPath.
81 ResponseForPath::~ResponseForPath() {
// Out-of-line destructor for SimpleResponse.
84 SimpleResponse::~SimpleResponse() {
// Determines the MIME type of the served file by sniffing its contents.
// A leading portion of the mapped file (no more than the local buffer holds)
// is copied and handed, together with the file path, to the Windows
// FindMimeFromData API. The wide string it returns is converted to ASCII and
// stored in |content_type|, then released with CoTaskMemFree.
// The final return reports whether |content_type| ended up non-empty.
87 bool FileResponse::GetContentType(std::string
* content_type
) const {
88 size_t length
= ContentLength();
93 // Create a copy of the first few bytes of the file.
94 // If we try and use the mapped file directly, FindMimeFromData will crash
95 // 'cause it cheats and temporarily tries to write to the buffer!
96 length
= std::min(arraysize(buffer
), length
);
97 memcpy(buffer
, file_
->data(), length
);
101 LPOLESTR mime_type
= NULL
;
102 FindMimeFromData(NULL
, file_path_
.value().c_str(), data
, length
, NULL
,
103 FMFD_DEFAULT
, &mime_type
, 0);
105 *content_type
= WideToASCII(mime_type
);
106 ::CoTaskMemFree(mime_type
);
109 return content_type
->length() > 0;
// Writes the body of the response: sends the mapped file's bytes
// (file_->data(), file_->length()) over |socket|, passing false as the last
// argument of Send().
112 void FileResponse::WriteContents(net::StreamListenSocket
* socket
) const {
115 socket
->Send(reinterpret_cast<const char*>(file_
->data()),
116 file_
->length(), false);
// Returns the length in bytes of the file being served.
// The mapping is created on first use: when file_ is still NULL a new
// base::MemoryMappedFile is initialised from file_path_. If file_ holds no
// mapping at the end, 0 is returned.
120 size_t FileResponse::ContentLength() const {
121 if (file_
.get() == NULL
) {
122 file_
.reset(new base::MemoryMappedFile());
123 if (!file_
->Initialize(file_path_
)) {
128 return file_
.get() ? file_
->length() : 0;
// Fills |headers| with a complete "302 Found" header block (zero-length body,
// connection closed) whose Location field is redirect_url_.
131 bool RedirectResponse::GetCustomHeaders(std::string
* headers
) const {
132 *headers
= base::StringPrintf("HTTP/1.1 302 Found\r\n"
133 "Connection: close\r\n"
134 "Content-Length: 0\r\n"
135 "Content-Type: text/html\r\n"
136 "Location: %hs\r\n\r\n",
137 redirect_url_
.c_str());
// Creates a server on |port| using the address reported by
// chrome_frame_test::GetLocalIPv4Address(); the work is done by Construct().
141 SimpleWebServer::SimpleWebServer(int port
) {
142 Construct(chrome_frame_test::GetLocalIPv4Address(), port
);
// Creates a server on an explicit |address| and |port| by delegating to
// Construct().
145 SimpleWebServer::SimpleWebServer(const std::string
& address
, int port
) {
146 Construct(address
, port
);
// Tears the server down: visits every entry of connections_ and then clears
// the list.
// NOTE(review): the per-entry action is presumably a delete of the
// Connection object created in DidAccept -- confirm.
149 SimpleWebServer::~SimpleWebServer() {
150 ConnectionList::const_iterator it
;
151 for (it
= connections_
.begin(); it
!= connections_
.end(); ++it
)
153 connections_
.clear();
// Shared constructor body.
// CHECKs that a MessageLoop exists on the current thread, ensures Winsock is
// initialised, then creates the listening socket for |address|:|port| with
// this object passed as the third argument (presumably the delegate that
// receives DidAccept/DidRead/DidClose). A failure to create the listener is
// reported through LOG_IF(DFATAL).
156 void SimpleWebServer::Construct(const std::string
& address
, int port
) {
157 CHECK(base::MessageLoop::current())
158 << "SimpleWebServer requires a message loop";
159 net::EnsureWinsockInit();
162 server_
= net::TCPListenSocket::CreateAndListen(address
, port
, this);
163 LOG_IF(DFATAL
, !server_
.get())
164 << "Failed to create listener socket at " << address
<< ":" << port
;
// Registers |response| by appending it to responses_; FindResponse() later
// consults the list in insertion order.
167 void SimpleWebServer::AddResponse(Response
* response
) {
168 responses_
.push_back(response
);
// Iterates over every entry in responses_.
// NOTE(review): going by the name, each Response is presumably deleted and
// the list emptied afterwards -- confirm.
171 void SimpleWebServer::DeleteAllResponses() {
172 std::list
<Response
*>::const_iterator it
;
173 for (it
= responses_
.begin(); it
!= responses_
.end(); ++it
) {
// Looks for a registered response that handles |request|: walks responses_
// in order and tests each entry with Matches(request).
// NOTE(review): presumably returns the first match and NULL when nothing
// matches (DidClose treats a false result as "no response") -- confirm.
179 Response
* SimpleWebServer::FindResponse(const Request
& request
) const {
180 std::list
<Response
*>::const_iterator it
;
181 for (it
= responses_
.begin(); it
!= responses_
.end(); it
++) {
182 Response
* response
= (*it
);
183 if (response
->Matches(request
)) {
// Looks up the Connection that belongs to |socket| by walking connections_
// and testing each entry with IsSame(socket).
190 Connection
* SimpleWebServer::FindConnection(
191 const net::StreamListenSocket
* socket
) const {
192 ConnectionList::const_iterator it
;
193 for (it
= connections_
.begin(); it
!= connections_
.end(); it
++) {
194 if ((*it
)->IsSame(socket
)) {
// Called when a client connects: wraps the accepted |connection| socket in a
// new Connection object and appends it to connections_.
201 void SimpleWebServer::DidAccept(net::StreamListenSocket
* server
,
202 net::StreamListenSocket
* connection
) {
203 connections_
.push_back(new Connection(connection
));
// Called when data arrives on |connection|.
// The bytes are forwarded to the Request of the matching Connection via
// OnDataReceived(). Once that request reports AllContentReceived(), a
// Response is looked up with FindResponse():
//  - For a response, its custom headers are used when GetCustomHeaders()
//    provides them; otherwise headers are built from kDefaultHeaderTemplate
//    with kStatusOk, the response's content type (kDefaultContentType as the
//    fallback) and its ContentLength(). Headers and contents are sent and the
//    response's access counter is incremented.
//  - Otherwise a "sorry, I can't find <path>" payload is sent, preceded by
//    headers built from the same template.
206 void SimpleWebServer::DidRead(net::StreamListenSocket
* connection
,
209 Connection
* c
= FindConnection(connection
);
211 Request
& r
= c
->request();
212 std::string
str(data
, len
);
213 r
.OnDataReceived(str
);
214 if (r
.AllContentReceived()) {
215 const Request
& request
= c
->request();
216 Response
* response
= FindResponse(request
);
219 if (!response
->GetCustomHeaders(&headers
)) {
220 std::string content_type
;
221 if (!response
->GetContentType(&content_type
))
222 content_type
= kDefaultContentType
;
223 headers
= base::StringPrintf(kDefaultHeaderTemplate
, kStatusOk
,
224 content_type
.c_str(),
225 response
->ContentLength());
228 connection
->Send(headers
, false);
229 response
->WriteContents(connection
);
230 response
->IncrementAccessCounter();
232 std::string payload
= "sorry, I can't find " + request
.path();
233 std::string
headers(base::StringPrintf(kDefaultHeaderTemplate
,
237 connection
->Send(headers
, false);
238 connection
->Send(payload
, false);
// Called when |sock| is closed. Connections whose request has no matching
// response (the 404 case) are removed from connections_; other connections
// stay in the list as history.
243 void SimpleWebServer::DidClose(net::StreamListenSocket
* sock
) {
244 // To keep the historical list of connections reasonably tidy, we delete
245 // 404's when the connection ends.
246 Connection
* c
= FindConnection(sock
);
249 if (!FindResponse(c
->request())) {
250 // extremely inefficient, but in one line and not that common... :)
251 connections_
.erase(std::find(connections_
.begin(), connections_
.end(), c
));
// Stores |port|, |address| and |root_dir|, ensures Winsock is initialised and
// starts listening on the UTF-8 form of |address| at |port|, passing this
// object as the third argument of CreateAndListen().
256 HTTPTestServer::HTTPTestServer(int port
, const std::wstring
& address
,
257 base::FilePath root_dir
)
258 : port_(port
), address_(address
), root_dir_(root_dir
) {
259 net::EnsureWinsockInit();
261 net::TCPListenSocket::CreateAndListen(WideToUTF8(address
), port
, this);
// Out-of-line destructor for HTTPTestServer.
264 HTTPTestServer::~HTTPTestServer() {
// Searches connection_list_ for the entry whose socket_ equals |socket| and
// yields an iterator into the list. As a side effect, entries whose socket_
// is NULL are erased during the scan (erase(it++) keeps the loop iterator
// valid), which is why the for statement has no increment clause.
268 std::list
<scoped_refptr
<ConfigurableConnection
>>::iterator
269 HTTPTestServer::FindConnection(const net::StreamListenSocket
* socket
) {
270 ConnectionList::iterator it
;
271 // Scan through the list searching for the desired socket. Along the way,
272 // erase any connections for which the corresponding socket has already been
273 // forgotten about as a result of all data having been sent.
274 for (it
= connection_list_
.begin(); it
!= connection_list_
.end(); ) {
275 ConfigurableConnection
* connection
= it
->get();
276 if (connection
->socket_
== NULL
) {
277 connection_list_
.erase(it
++);
280 if (connection
->socket_
== socket
)
// Convenience wrapper around FindConnection(): locates the list entry for
// |socket| and, when the iterator is not end(), hands back the connection as
// a scoped_refptr.
288 scoped_refptr
<ConfigurableConnection
> HTTPTestServer::ConnectionFromSocket(
289 const net::StreamListenSocket
* socket
) {
290 ConnectionList::iterator it
= FindConnection(socket
);
291 if (it
!= connection_list_
.end())
// Called when a client connects: creates a ConfigurableConnection for the
// accepted |socket| and appends it to connection_list_.
296 void HTTPTestServer::DidAccept(net::StreamListenSocket
* server
,
297 net::StreamListenSocket
* socket
) {
298 connection_list_
.push_back(new ConfigurableConnection(socket
));
// Called when data arrives on |socket|.
// The bytes are fed into the request (r_) of the connection found through
// ConnectionFromSocket(). Once the request reports AllContentReceived(), the
// method and path are logged and the request is dispatched with the path
// converted to a wide string: Post() when the method matches "post"
// (compared with LowerCaseEqualsASCII), Get() for the alternative.
301 void HTTPTestServer::DidRead(net::StreamListenSocket
* socket
,
304 scoped_refptr
<ConfigurableConnection
> connection
=
305 ConnectionFromSocket(socket
);
307 std::string
str(data
, len
);
308 connection
->r_
.OnDataReceived(str
);
309 if (connection
->r_
.AllContentReceived()) {
310 VLOG(1) << __FUNCTION__
<< ": " << connection
->r_
.method() << " "
311 << connection
->r_
.path();
312 std::wstring path
= UTF8ToWide(connection
->r_
.path());
313 if (LowerCaseEqualsASCII(connection
->r_
.method(), "post"))
314 this->Post(connection
, path
, connection
->r_
);
316 this->Get(connection
, path
, connection
->r_
);
// Called when |socket| is closed: removes the matching entry, if any, from
// connection_list_.
321 void HTTPTestServer::DidClose(net::StreamListenSocket
* socket
) {
322 ConnectionList::iterator it
= FindConnection(socket
);
323 if (it
!= connection_list_
.end())
324 connection_list_
.erase(it
);
// Builds an absolute http:// URL on this server for |path|.
// A single leading '/' is stripped from |path| first. Two URL shapes are
// produced, each with and without the path suffix: "http://<address>[/<path>]"
// and "http://<address>:<port>[/<path>]".
// NOTE(review): the port-less form is presumably selected when port_ is the
// default HTTP port -- confirm.
327 std::wstring
HTTPTestServer::Resolve(const std::wstring
& path
) {
328 // Remove the first '/' if needed.
329 std::wstring stripped_path
= path
;
330 if (path
.size() && path
[0] == L
'/')
331 stripped_path
= path
.substr(1);
334 if (stripped_path
.empty()) {
335 return base::StringPrintf(L
"http://%ls", address_
.c_str());
337 return base::StringPrintf(L
"http://%ls/%ls", address_
.c_str(),
338 stripped_path
.c_str());
341 if (stripped_path
.empty()) {
342 return base::StringPrintf(L
"http://%ls:%d", address_
.c_str(), port_
);
344 return base::StringPrintf(L
"http://%ls:%d/%ls", address_
.c_str(), port_
,
345 stripped_path
.c_str());
// Sends the next slice of the queued data_.
// At most options_.chunk_size_ bytes starting at cur_pos_ are written to
// socket_ and cur_pos_ is advanced by the amount sent. If data remains, the
// method re-posts itself to the current MessageLoop after options_.timeout_
// milliseconds; otherwise socket_ is set to 0 to close the connection.
350 void ConfigurableConnection::SendChunk() {
351 int size
= (int)data_
.size();
352 const char* chunk_ptr
= data_
.c_str() + cur_pos_
;
353 int bytes_to_send
= std::min(options_
.chunk_size_
, size
- cur_pos_
);
355 socket_
->Send(chunk_ptr
, bytes_to_send
);
356 VLOG(1) << "Sent(" << cur_pos_
<< "," << bytes_to_send
<< "): "
357 << base::StringPiece(chunk_ptr
, bytes_to_send
);
359 cur_pos_
+= bytes_to_send
;
360 if (cur_pos_
< size
) {
361 base::MessageLoop::current()->PostDelayedTask(
362 FROM_HERE
, base::Bind(&ConfigurableConnection::SendChunk
, this),
363 base::TimeDelta::FromMilliseconds(options_
.timeout_
));
365 socket_
= 0; // close the connection.
// Closes the connection. SendWithOptions() posts this as a task after an
// immediate send rather than calling it directly.
// NOTE(review): presumably releases socket_ -- confirm.
369 void ConfigurableConnection::Close() {
// Sends |headers| and |content| right away: shorthand for SendWithOptions()
// with SendOptions::IMMEDIATE and zero for the two remaining option
// arguments.
373 void ConfigurableConnection::Send(const std::string
& headers
,
374 const std::string
& content
) {
375 SendOptions
options(SendOptions::IMMEDIATE
, 0, 0);
376 SendWithOptions(headers
, content
, options
);
// Sends a response made of |headers| and |content| according to |options|.
// A "Content-Length" header is synthesised when |content| is non-empty and
// |headers| does not already carry one. Depending on options_.speed_:
//  - IMMEDIATE: headers, the synthesised header and the content are written
//    to socket_ at once and a task is posted to Close() the connection
//    (posted rather than called because StreamListenSocket must not be
//    destroyed from within its own callbacks).
//  - IMMEDIATE_HEADERS_DELAYED_CONTENT: the headers are written now and the
//    content is queued in data_.
//  - DELAYED: the synthesised header and the blank line that ends the header
//    block are queued in data_.
// The final statement schedules SendChunk() after options.timeout_
// milliseconds to drain data_.
379 void ConfigurableConnection::SendWithOptions(const std::string
& headers
,
380 const std::string
& content
,
381 const SendOptions
& options
) {
382 std::string content_length_header
;
// Only add Content-Length when the caller did not supply it. This check used
// to search for the misspelt "Context-Length:", which never matched a
// caller-provided header and so produced a duplicate Content-Length field.
383 if (!content
.empty() &&
384 std::string::npos
== headers
.find("Content-Length:")) {
385 content_length_header
= base::StringPrintf("Content-Length: %u\r\n",
392 if (options_
.speed_
== SendOptions::IMMEDIATE
) {
393 socket_
->Send(headers
);
394 socket_
->Send(content_length_header
, true);
395 socket_
->Send(content
);
396 // Post a task to close the socket since StreamListenSocket doesn't like
397 // instances to go away from within its callbacks.
398 base::MessageLoop::current()->PostTask(
399 FROM_HERE
, base::Bind(&ConfigurableConnection::Close
, this));
404 if (options_
.speed_
== SendOptions::IMMEDIATE_HEADERS_DELAYED_CONTENT
) {
405 socket_
->Send(headers
);
406 socket_
->Send(content_length_header
, true);
407 VLOG(1) << "Headers sent: " << headers
<< content_length_header
;
408 data_
.append(content
);
411 if (options_
.speed_
== SendOptions::DELAYED
) {
413 data_
.append(content_length_header
);
414 data_
.append("\r\n");
417 base::MessageLoop::current()->PostDelayedTask(
418 FROM_HERE
, base::Bind(&ConfigurableConnection::SendChunk
, this),
419 base::TimeDelta::FromMilliseconds(options
.timeout_
));
422 } // namespace test_server