1 // Copyright (c) 2006-2008 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
9 #include "base/logging.h"
10 #include "base/string_number_conversions.h"
11 #include "base/string_util.h"
12 #include "base/utf_string_conversions.h"
14 #include "chrome_frame/test/test_server.h"
16 #include "net/base/winsock_init.h"
17 #include "net/http/http_util.h"
19 namespace test_server
{
// Response header block written before a response body.
// printf-style arguments, in order:
//   1. HTTP status text (kStatusOk / kStatusNotFound),
//   2. Content-Type value,
//   3. Content-Length.
// The status line must come first so that these three arguments line up with
// the three conversions below.
// NOTE(review): callers pass a size_t length for %i; this assumes the length
// fits in an int — confirm for large test files.
const char kDefaultHeaderTemplate[] =
    "HTTP/1.1 %hs\r\n"
    "Connection: close\r\n"
    "Content-Type: %hs\r\n"
    "Content-Length: %i\r\n\r\n";
// Status text for a request that matched a registered response.
const char kStatusOk[] = "200 OK";
// Status text for a request with no matching registered response.
const char kStatusNotFound[] = "404 Not Found";
// Content type used when a response supplies none, and for 404 replies.
const char kDefaultContentType[] = "text/html; charset=UTF-8";
29 void Request::ParseHeaders(const std::string
& headers
) {
30 DCHECK(method_
.length() == 0);
32 size_t pos
= headers
.find("\r\n");
33 DCHECK(pos
!= std::string::npos
);
34 if (pos
!= std::string::npos
) {
35 headers_
= headers
.substr(pos
+ 2);
37 StringTokenizer
tokenizer(headers
.begin(), headers
.begin() + pos
, " ");
38 std::string
* parse
[] = { &method_
, &path_
, &version_
};
40 while (tokenizer
.GetNext() && field
< arraysize(parse
)) {
41 parse
[field
++]->assign(tokenizer
.token_begin(),
42 tokenizer
.token_end());
46 // Check for content-length in case we're being sent some data.
47 net::HttpUtil::HeadersIterator
it(headers_
.begin(), headers_
.end(),
49 while (it
.GetNext()) {
50 if (LowerCaseEqualsASCII(it
.name(), "content-length")) {
51 int int_content_length
;
52 base::StringToInt(it
.values().c_str(), &int_content_length
);
53 content_length_
= int_content_length
;
59 void Request::OnDataReceived(const std::string
& data
) {
62 if (method_
.length() == 0) {
63 size_t index
= content_
.find("\r\n\r\n");
64 if (index
!= std::string::npos
) {
65 // Parse the headers before returning and chop them of the
66 // data buffer we've already received.
67 std::string
headers(content_
.substr(0, index
+ 2));
68 ParseHeaders(headers
);
69 content_
.erase(0, index
+ 4);
74 bool FileResponse::GetContentType(std::string
* content_type
) const {
75 size_t length
= ContentLength();
80 // Create a copy of the first few bytes of the file.
81 // If we try and use the mapped file directly, FindMimeFromData will crash
82 // 'cause it cheats and temporarily tries to write to the buffer!
83 length
= std::min(arraysize(buffer
), length
);
84 memcpy(buffer
, file_
->data(), length
);
88 LPOLESTR mime_type
= NULL
;
89 FindMimeFromData(NULL
, file_path_
.value().c_str(), data
, length
, NULL
,
90 FMFD_DEFAULT
, &mime_type
, 0);
92 *content_type
= WideToASCII(mime_type
);
93 ::CoTaskMemFree(mime_type
);
96 return content_type
->length() > 0;
99 void FileResponse::WriteContents(ListenSocket
* socket
) const {
102 socket
->Send(reinterpret_cast<const char*>(file_
->data()),
103 file_
->length(), false);
107 size_t FileResponse::ContentLength() const {
108 if (file_
.get() == NULL
) {
109 file_
.reset(new file_util::MemoryMappedFile());
110 if (!file_
->Initialize(file_path_
)) {
115 return file_
.get() ? file_
->length() : 0;
118 bool RedirectResponse::GetCustomHeaders(std::string
* headers
) const {
119 *headers
= StringPrintf("HTTP/1.1 302 Found\r\n"
120 "Connection: close\r\n"
121 "Content-Length: 0\r\n"
122 "Content-Type: text/html\r\n"
123 "Location: %hs\r\n\r\n", redirect_url_
.c_str());
127 SimpleWebServer::SimpleWebServer(int port
) {
128 CHECK(MessageLoop::current()) << "SimpleWebServer requires a message loop";
129 net::EnsureWinsockInit();
131 server_
= ListenSocket::Listen("127.0.0.1", port
, this);
132 DCHECK(server_
.get() != NULL
);
135 SimpleWebServer::~SimpleWebServer() {
136 ConnectionList::const_iterator it
;
137 for (it
= connections_
.begin(); it
!= connections_
.end(); ++it
)
139 connections_
.clear();
142 void SimpleWebServer::AddResponse(Response
* response
) {
143 responses_
.push_back(response
);
146 void SimpleWebServer::DeleteAllResponses() {
147 std::list
<Response
*>::const_iterator it
;
148 for (it
= responses_
.begin(); it
!= responses_
.end(); ++it
) {
152 connections_
.clear();
155 Response
* SimpleWebServer::FindResponse(const Request
& request
) const {
156 std::list
<Response
*>::const_iterator it
;
157 for (it
= responses_
.begin(); it
!= responses_
.end(); it
++) {
158 Response
* response
= (*it
);
159 if (response
->Matches(request
)) {
166 Connection
* SimpleWebServer::FindConnection(const ListenSocket
* socket
) const {
167 ConnectionList::const_iterator it
;
168 for (it
= connections_
.begin(); it
!= connections_
.end(); it
++) {
169 if ((*it
)->IsSame(socket
)) {
176 void SimpleWebServer::DidAccept(ListenSocket
* server
,
177 ListenSocket
* connection
) {
178 connections_
.push_back(new Connection(connection
));
181 void SimpleWebServer::DidRead(ListenSocket
* connection
,
184 Connection
* c
= FindConnection(connection
);
186 Request
& r
= c
->request();
187 std::string
str(data
, len
);
188 r
.OnDataReceived(str
);
189 if (r
.AllContentReceived()) {
190 const Request
& request
= c
->request();
191 Response
* response
= FindResponse(request
);
194 if (!response
->GetCustomHeaders(&headers
)) {
195 std::string content_type
;
196 if (!response
->GetContentType(&content_type
))
197 content_type
= kDefaultContentType
;
198 headers
= StringPrintf(kDefaultHeaderTemplate
, kStatusOk
,
199 content_type
.c_str(), response
->ContentLength());
202 connection
->Send(headers
, false);
203 response
->WriteContents(connection
);
204 response
->IncrementAccessCounter();
206 std::string payload
= "sorry, I can't find " + request
.path();
207 std::string
headers(StringPrintf(kDefaultHeaderTemplate
, kStatusNotFound
,
208 kDefaultContentType
, payload
.length()));
209 connection
->Send(headers
, false);
210 connection
->Send(payload
, false);
215 void SimpleWebServer::DidClose(ListenSocket
* sock
) {
216 // To keep the historical list of connections reasonably tidy, we delete
217 // 404's when the connection ends.
218 Connection
* c
= FindConnection(sock
);
220 if (!FindResponse(c
->request())) {
221 // extremely inefficient, but in one line and not that common... :)
222 connections_
.erase(std::find(connections_
.begin(), connections_
.end(), c
));
227 HTTPTestServer::HTTPTestServer(int port
, const std::wstring
& address
,
229 : port_(port
), address_(address
), root_dir_(root_dir
) {
230 net::EnsureWinsockInit();
231 server_
= ListenSocket::Listen(WideToUTF8(address
), port
, this);
234 HTTPTestServer::~HTTPTestServer() {
237 std::list
<scoped_refptr
<ConfigurableConnection
>>::iterator
238 HTTPTestServer::FindConnection(const ListenSocket
* socket
) {
239 ConnectionList::iterator it
;
240 for (it
= connection_list_
.begin(); it
!= connection_list_
.end(); ++it
) {
241 if ((*it
)->socket_
== socket
) {
249 scoped_refptr
<ConfigurableConnection
> HTTPTestServer::ConnectionFromSocket(
250 const ListenSocket
* socket
) {
251 ConnectionList::iterator it
= FindConnection(socket
);
252 if (it
!= connection_list_
.end())
257 void HTTPTestServer::DidAccept(ListenSocket
* server
, ListenSocket
* socket
) {
258 connection_list_
.push_back(new ConfigurableConnection(socket
));
261 void HTTPTestServer::DidRead(ListenSocket
* socket
,
264 scoped_refptr
<ConfigurableConnection
> connection
=
265 ConnectionFromSocket(socket
);
267 std::string
str(data
, len
);
268 connection
->r_
.OnDataReceived(str
);
269 if (connection
->r_
.AllContentReceived()) {
270 std::wstring path
= UTF8ToWide(connection
->r_
.path());
271 if (LowerCaseEqualsASCII(connection
->r_
.method(), "post"))
272 this->Post(connection
, path
, connection
->r_
);
274 this->Get(connection
, path
, connection
->r_
);
279 void HTTPTestServer::DidClose(ListenSocket
* socket
) {
280 ConnectionList::iterator it
= FindConnection(socket
);
281 DCHECK(it
!= connection_list_
.end());
282 connection_list_
.erase(it
);
285 std::wstring
HTTPTestServer::Resolve(const std::wstring
& path
) {
286 // Remove the first '/' if needed.
287 std::wstring stripped_path
= path
;
288 if (path
.size() && path
[0] == L
'/')
289 stripped_path
= path
.substr(1);
292 if (stripped_path
.empty()) {
293 return StringPrintf(L
"http://%ls", address_
.c_str());
295 return StringPrintf(L
"http://%ls/%ls", address_
.c_str(),
296 stripped_path
.c_str());
299 if (stripped_path
.empty()) {
300 return StringPrintf(L
"http://%ls:%d", address_
.c_str(), port_
);
302 return StringPrintf(L
"http://%ls:%d/%ls", address_
.c_str(), port_
,
303 stripped_path
.c_str());
308 void ConfigurableConnection::SendChunk() {
309 int size
= (int)data_
.size();
310 const char* chunk_ptr
= data_
.c_str() + cur_pos_
;
311 int bytes_to_send
= std::min(options_
.chunk_size_
, size
- cur_pos_
);
313 socket_
->Send(chunk_ptr
, bytes_to_send
);
314 DLOG(INFO
) << "Sent(" << cur_pos_
<< "," << bytes_to_send
315 << "): " << base::StringPiece(chunk_ptr
, bytes_to_send
);
317 cur_pos_
+= bytes_to_send
;
318 if (cur_pos_
< size
) {
319 MessageLoop::current()->PostDelayedTask(FROM_HERE
, NewRunnableMethod(this,
320 &ConfigurableConnection::SendChunk
), options_
.timeout_
);
322 socket_
= 0; // close the connection.
326 void ConfigurableConnection::Send(const std::string
& headers
,
327 const std::string
& content
) {
328 SendOptions
options(SendOptions::IMMEDIATE
, 0, 0);
329 SendWithOptions(headers
, content
, options
);
332 void ConfigurableConnection::SendWithOptions(const std::string
& headers
,
333 const std::string
& content
,
334 const SendOptions
& options
) {
335 std::string content_length_header
;
336 if (!content
.empty() &&
337 std::string::npos
== headers
.find("Context-Length:")) {
338 content_length_header
= StringPrintf("Content-Length: %u\r\n",
345 if (options_
.speed_
== SendOptions::IMMEDIATE
) {
346 socket_
->Send(headers
);
347 socket_
->Send(content_length_header
, true);
348 socket_
->Send(content
);
349 socket_
= 0; // close the connection.
353 if (options_
.speed_
== SendOptions::IMMEDIATE_HEADERS_DELAYED_CONTENT
) {
354 socket_
->Send(headers
);
355 socket_
->Send(content_length_header
, true);
356 DLOG(INFO
) << "Headers sent: " << headers
<< content_length_header
;
357 data_
.append(content
);
360 if (options_
.speed_
== SendOptions::DELAYED
) {
362 data_
.append(content_length_header
);
363 data_
.append("\r\n");
366 MessageLoop::current()->PostDelayedTask(FROM_HERE
,
367 NewRunnableMethod(this, &ConfigurableConnection::SendChunk
),
371 } // namespace test_server