1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
9 #include "base/bind_helpers.h"
10 #include "base/compiler_specific.h"
11 #include "base/format_macros.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/weak_ptr.h"
15 #include "base/message_loop/message_loop.h"
16 #include "base/message_loop/message_loop_proxy.h"
17 #include "base/run_loop.h"
18 #include "base/strings/string_split.h"
19 #include "base/strings/string_util.h"
20 #include "base/strings/stringprintf.h"
21 #include "base/time/time.h"
22 #include "net/base/address_list.h"
23 #include "net/base/io_buffer.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/net_log.h"
27 #include "net/base/test_completion_callback.h"
28 #include "net/server/http_server.h"
29 #include "net/server/http_server_request_info.h"
30 #include "net/socket/tcp_client_socket.h"
31 #include "net/socket/tcp_listen_socket.h"
32 #include "net/url_request/url_fetcher.h"
33 #include "net/url_request/url_fetcher_delegate.h"
34 #include "net/url_request/url_request_context.h"
35 #include "net/url_request/url_request_context_getter.h"
36 #include "net/url_request/url_request_test_util.h"
37 #include "testing/gtest/include/gtest/gtest.h"
// Size in bytes of the buffer that TestHttpClient::ReadInternal allocates for
// each socket read; the same value is passed as the maximum read length.
43 const int kMaxExpectedResponseLength
= 2048;
// Target of the delayed task posted by RunLoopWithTimeout. Receives a weak
// pointer to the caller's timed-out flag and the closure that quits the
// caller's run loop.
// NOTE(review): presumably sets the flag and runs |quit_loop_func| only while
// the weak pointer is still valid -- confirm against the function body.
45 void SetTimedOutAndQuitLoop(const base::WeakPtr
<bool> timed_out
,
46 const base::Closure
& quit_loop_func
) {
// Arms a one-second timeout for |run_loop|: posts a delayed task on the
// current MessageLoop that calls SetTimedOutAndQuitLoop with a weak pointer to
// the local |timed_out| flag and the loop's quit closure.
// Callers in this file treat a true return as success (they ASSERT_TRUE on it
// or bail out when it is false).
// NOTE(review): presumably runs the loop and returns !timed_out -- confirm.
53 bool RunLoopWithTimeout(base::RunLoop
* run_loop
) {
54 bool timed_out
= false;
// The flag is handed to the task through a WeakPtr rather than a raw pointer
// to this stack variable.
55 base::WeakPtrFactory
<bool> timed_out_weak_factory(&timed_out
);
56 base::MessageLoop::current()->PostDelayedTask(
58 base::Bind(&SetTimedOutAndQuitLoop
,
59 timed_out_weak_factory
.GetWeakPtr(),
60 run_loop
->QuitClosure()),
61 base::TimeDelta::FromSeconds(1));
// Minimal TCP client used by the tests below: it connects a TCPClientSocket to
// a given endpoint, writes raw request bytes and reads raw response bytes.
66 class TestHttpClient
{
68 TestHttpClient() : connect_result_(OK
) {}
// Creates a TCPClientSocket for |address| and starts connecting, with
// OnConnect (bound to a local RunLoop's quit closure) as the completion
// callback. A synchronous failure (anything other than OK or ERR_IO_PENDING)
// is returned immediately; otherwise the call waits via RunLoopWithTimeout.
70 int ConnectAndWait(const IPEndPoint
& address
) {
71 AddressList
addresses(address
);
72 NetLog::Source source
;
73 socket_
.reset(new TCPClientSocket(addresses
, NULL
, source
));
75 base::RunLoop run_loop
;
76 connect_result_
= socket_
->Connect(base::Bind(&TestHttpClient::OnConnect
,
77 base::Unretained(this),
78 run_loop
.QuitClosure()));
// NOTE(review): a synchronous OK also falls through to the timed wait below.
// Confirm that OnConnect still runs in that case; otherwise the wait can only
// end through the timeout.
79 if (connect_result_
!= OK
&& connect_result_
!= ERR_IO_PENDING
)
80 return connect_result_
;
82 if (!RunLoopWithTimeout(&run_loop
))
84 return connect_result_
;
// Wraps |data| in a DrainableIOBuffer (backed by a StringIOBuffer) covering
// data.length() bytes.
// NOTE(review): presumably stores it in write_buffer_ and starts the write --
// confirm.
87 void Send(const std::string
& data
) {
89 new DrainableIOBuffer(new StringIOBuffer(data
), data
.length());
// Convenience overload: forwards to Read(message, 1), i.e. waits for at least
// one byte.
93 bool Read(std::string
* message
) {
94 return Read(message
, 1);
// Issues reads until at least |expected_bytes| bytes have been received in
// total, blocking on a TestCompletionCallback for each read and appending
// every received chunk to |*message|.
97 bool Read(std::string
* message
, int expected_bytes
) {
98 int total_bytes_received
= 0;
100 while (total_bytes_received
< expected_bytes
) {
101 net::TestCompletionCallback callback
;
102 ReadInternal(callback
.callback());
103 int bytes_received
= callback
.WaitForResult();
// A non-positive result (error or no data) takes a separate branch.
// NOTE(review): presumably reported to the caller as failure -- confirm.
104 if (bytes_received
<= 0)
107 total_bytes_received
+= bytes_received
;
108 message
->append(read_buffer_
->data(), bytes_received
);
// Connect completion callback: records the result in connect_result_.
// NOTE(review): presumably also runs |quit_loop| to release ConnectAndWait --
// confirm.
114 void OnConnect(const base::Closure
& quit_loop
, int result
) {
115 connect_result_
= result
;
// Writes the bytes still remaining in write_buffer_ to the socket, with
// OnWrite as the completion callback. A result other than ERR_IO_PENDING
// takes the synchronous-completion branch.
120 int result
= socket_
->Write(
122 write_buffer_
->BytesRemaining(),
123 base::Bind(&TestHttpClient::OnWrite
, base::Unretained(this)));
124 if (result
!= ERR_IO_PENDING
)
// Write completion: asserts that at least one byte was written, marks
// |result| bytes of write_buffer_ as consumed, and checks whether any bytes
// are still pending.
128 void OnWrite(int result
) {
129 ASSERT_GT(result
, 0);
130 write_buffer_
->DidConsume(result
);
131 if (write_buffer_
->BytesRemaining())
// Allocates a fresh read buffer of kMaxExpectedResponseLength bytes and starts
// a socket read into it. If the read does not return ERR_IO_PENDING,
// |callback| is run directly with the result.
135 void ReadInternal(const net::CompletionCallback
& callback
) {
136 read_buffer_
= new IOBufferWithSize(kMaxExpectedResponseLength
);
137 int result
= socket_
->Read(read_buffer_
,
138 kMaxExpectedResponseLength
,
140 if (result
!= ERR_IO_PENDING
)
141 callback
.Run(result
);
// Buffer filled by the most recent ReadInternal call.
144 scoped_refptr
<IOBufferWithSize
> read_buffer_
;
// Outgoing data; consumed progressively as writes complete (see OnWrite).
145 scoped_refptr
<DrainableIOBuffer
> write_buffer_
;
// Socket created by ConnectAndWait.
146 scoped_ptr
<TCPClientSocket
> socket_
;
// Test fixture that creates an HttpServer listening on 127.0.0.1 (port 0) and
// acts as its delegate, recording every HTTP request the server reports.
152 class HttpServerTest
: public testing::Test
,
153 public HttpServer::Delegate
{
155 HttpServerTest() : quit_after_request_count_(0) {}
// Starts the server with this fixture as delegate and reads back the address
// it is actually bound to into server_address_.
157 virtual void SetUp() OVERRIDE
{
158 TCPListenSocketFactory
socket_factory("127.0.0.1", 0);
159 server_
= new HttpServer(socket_factory
, this);
160 ASSERT_EQ(OK
, server_
->GetLocalAddress(&server_address_
));
// Delegate hook: stores (|info|, |connection_id|) and, once the number of
// stored requests equals quit_after_request_count_, runs run_loop_quit_func_.
163 virtual void OnHttpRequest(int connection_id
,
164 const HttpServerRequestInfo
& info
) OVERRIDE
{
165 requests_
.push_back(std::make_pair(info
, connection_id
));
166 if (requests_
.size() == quit_after_request_count_
)
167 run_loop_quit_func_
.Run();
// Delegate hook for WebSocket upgrade requests.
170 virtual void OnWebSocketRequest(int connection_id
,
171 const HttpServerRequestInfo
& info
) OVERRIDE
{
// Delegate hook for WebSocket messages.
175 virtual void OnWebSocketMessage(int connection_id
,
176 const std::string
& data
) OVERRIDE
{
// Delegate hook for connection close; intentionally does nothing.
180 virtual void OnClose(int connection_id
) OVERRIDE
{}
// Waits until |count| requests have been recorded. Sets the target count,
// checks whether it has already been reached, and otherwise spins a RunLoop
// under RunLoopWithTimeout with its quit closure published in
// run_loop_quit_func_ so OnHttpRequest can stop it. The closure is reset
// afterwards.
// NOTE(review): presumably returns true when the count was already reached and
// |success| otherwise -- confirm.
182 bool RunUntilRequestsReceived(size_t count
) {
183 quit_after_request_count_
= count
;
184 if (requests_
.size() == count
)
187 base::RunLoop run_loop
;
188 run_loop_quit_func_
= run_loop
.QuitClosure();
189 bool success
= RunLoopWithTimeout(&run_loop
);
190 run_loop_quit_func_
.Reset();
// Returns a copy of the request info recorded at |request_index|.
194 HttpServerRequestInfo
GetRequest(size_t request_index
) {
195 return requests_
[request_index
].first
;
// Returns the connection id recorded alongside request |request_index|.
198 int GetConnectionId(size_t request_index
) {
199 return requests_
[request_index
].second
;
// Server under test, created in SetUp.
203 scoped_refptr
<HttpServer
> server_
;
// Address reported by the server's GetLocalAddress in SetUp.
204 IPEndPoint server_address_
;
// Quit closure of the RunLoop currently waiting in RunUntilRequestsReceived.
205 base::Closure run_loop_quit_func_
;
// Every request seen by OnHttpRequest, paired with its connection id, in
// arrival order.
206 std::vector
<std::pair
<HttpServerRequestInfo
, int> > requests_
;
// Request count at which OnHttpRequest runs run_loop_quit_func_.
209 size_t quit_after_request_count_
;
// Sends a bare GET with no headers or body and checks the parsed method, path,
// empty data and empty header map, plus a prefix of the peer address string.
212 TEST_F(HttpServerTest
, Request
) {
213 TestHttpClient client
;
214 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
215 client
.Send("GET /test HTTP/1.1\r\n\r\n");
216 ASSERT_TRUE(RunUntilRequestsReceived(1));
217 ASSERT_EQ("GET", GetRequest(0).method
);
218 ASSERT_EQ("/test", GetRequest(0).path
);
219 ASSERT_EQ("", GetRequest(0).data
);
220 ASSERT_EQ(0u, GetRequest(0).headers
.size());
221 ASSERT_TRUE(StartsWithASCII(GetRequest(0).peer
.ToString(),
// Sends a GET carrying a table of headers that vary in separator whitespace,
// embedded colons, empty values and a non-ASCII byte, then checks that each
// one is stored under its lower-cased name with the expected value.
226 TEST_F(HttpServerTest
, RequestWithHeaders
) {
227 TestHttpClient client
;
228 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
// Each row is {header name, separator as sent on the wire, expected value}.
229 const char* kHeaders
[][3] = {
230 {"Header", ": ", "1"},
231 {"HeaderWithNoWhitespace", ":", "1"},
232 {"HeaderWithWhitespace", " : \t ", "1 1 1 \t "},
233 {"HeaderWithColon", ": ", "1:1"},
234 {"EmptyHeader", ":", ""},
235 {"EmptyHeaderWithWhitespace", ": \t ", ""},
236 {"HeaderWithNonASCII", ": ", "\xf7"},
// Build the header section: name + separator + value + CRLF for every row.
239 for (size_t i
= 0; i
< arraysize(kHeaders
); ++i
) {
241 std::string(kHeaders
[i
][0]) + kHeaders
[i
][1] + kHeaders
[i
][2] + "\r\n";
244 client
.Send("GET /test HTTP/1.1\r\n" + headers
+ "\r\n");
245 ASSERT_TRUE(RunUntilRequestsReceived(1));
246 ASSERT_EQ("", GetRequest(0).data
);
// Every header must appear exactly once, keyed by its lower-cased name, with
// the value from column 2 of the table.
248 for (size_t i
= 0; i
< arraysize(kHeaders
); ++i
) {
249 std::string field
= StringToLowerASCII(std::string(kHeaders
[i
][0]));
250 std::string value
= kHeaders
[i
][2];
251 ASSERT_EQ(1u, GetRequest(0).headers
.count(field
)) << field
;
252 ASSERT_EQ(value
, GetRequest(0).headers
[field
]) << kHeaders
[i
][0];
// Sends a GET with a Content-Length header and a 1026-byte body ('a', 1024
// 'b's, 'c'), then checks that the server reports two headers and a body of
// the same length.
256 TEST_F(HttpServerTest
, RequestWithBody
) {
257 TestHttpClient client
;
258 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
259 std::string body
= "a" + std::string(1 << 10, 'b') + "c";
260 client
.Send(base::StringPrintf(
261 "GET /test HTTP/1.1\r\n"
263 "Content-Length: %" PRIuS
"\r\n\r\n%s",
266 ASSERT_TRUE(RunUntilRequestsReceived(1));
267 ASSERT_EQ(2u, GetRequest(0).headers
.size());
268 ASSERT_EQ(body
.length(), GetRequest(0).data
.length());
// NOTE(review): the next two assertions inspect the local |body|, not
// GetRequest(0).data, so they hold regardless of what the server parsed.
// They were probably meant to check the first and last byte of the received
// data -- confirm and fix.
269 ASSERT_EQ('a', body
[0]);
270 ASSERT_EQ('c', *body
.rbegin());
// Issues a request through URLFetcher that advertises a content-length of
// 1 << 30 and expects the server to answer with HTTP_INTERNAL_SERVER_ERROR
// without ever handing a request to the delegate.
273 TEST_F(HttpServerTest
, RequestWithTooLargeBody
) {
// Fetch delegate that checks the response code and then quits the run loop.
274 class TestURLFetcherDelegate
: public URLFetcherDelegate
{
276 TestURLFetcherDelegate(const base::Closure
& quit_loop_func
)
277 : quit_loop_func_(quit_loop_func
) {}
278 virtual ~TestURLFetcherDelegate() {}
280 virtual void OnURLFetchComplete(const URLFetcher
* source
) OVERRIDE
{
281 EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR
, source
->GetResponseCode());
282 quit_loop_func_
.Run();
286 base::Closure quit_loop_func_
;
289 base::RunLoop run_loop
;
290 TestURLFetcherDelegate
delegate(run_loop
.QuitClosure());
292 scoped_refptr
<URLRequestContextGetter
> request_context_getter(
293 new TestURLRequestContextGetter(base::MessageLoopProxy::current()));
// Target the server's bound port on 127.0.0.1.
294 scoped_ptr
<URLFetcher
> fetcher(
295 URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test",
296 server_address_
.port())),
299 fetcher
->SetRequestContext(request_context_getter
.get());
// Advertise a body of 1 << 30 bytes via an extra request header.
300 fetcher
->AddExtraRequestHeader(
301 base::StringPrintf("content-length:%d", 1 << 30));
// The fetch must complete within the timeout, and no request may have reached
// OnHttpRequest.
304 ASSERT_TRUE(RunLoopWithTimeout(&run_loop
));
305 ASSERT_EQ(0u, requests_
.size());
// After one request is received, answers it with Send200 and checks that the
// client reads a response starting with the 200 status line and ending with
// the supplied body text.
308 TEST_F(HttpServerTest
, Send200
) {
309 TestHttpClient client
;
310 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
311 client
.Send("GET /test HTTP/1.1\r\n\r\n");
312 ASSERT_TRUE(RunUntilRequestsReceived(1));
313 server_
->Send200(GetConnectionId(0), "Response!", "text/plain");
// The single-argument Read waits for at least one byte.
315 std::string response
;
316 ASSERT_TRUE(client
.Read(&response
));
317 ASSERT_TRUE(StartsWithASCII(response
, "HTTP/1.1 200 OK", true));
318 ASSERT_TRUE(EndsWith(response
, "Response!", true));
// Sends three raw chunks on the same connection with SendRaw and checks that
// the client receives exactly their concatenation, in order.
321 TEST_F(HttpServerTest
, SendRaw
) {
322 TestHttpClient client
;
323 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
324 client
.Send("GET /test HTTP/1.1\r\n\r\n");
325 ASSERT_TRUE(RunUntilRequestsReceived(1));
326 server_
->SendRaw(GetConnectionId(0), "Raw Data ");
327 server_
->SendRaw(GetConnectionId(0), "More Data");
328 server_
->SendRaw(GetConnectionId(0), "Third Piece of Data");
// Read until the full expected length has arrived; the chunks may be delivered
// across several socket reads.
330 const std::string
expected_response("Raw Data More DataThird Piece of Data");
331 std::string response
;
332 ASSERT_TRUE(client
.Read(&response
, expected_response
.length()));
333 ASSERT_EQ(expected_response
, response
);
// StreamListenSocket stand-in constructed over kInvalidSocket. It lets a test
// hand the server a connection object and drive the server's DidRead directly.
338 class MockStreamListenSocket
: public StreamListenSocket
{
340 MockStreamListenSocket(StreamListenSocket::Delegate
* delegate
)
341 : StreamListenSocket(kInvalidSocket
, delegate
) {}
// Must never be called on the mock.
343 virtual void Accept() OVERRIDE
{ NOTREACHED(); }
346 virtual ~MockStreamListenSocket() {}
// Feeds a request with a body to the server in two DidRead calls, withholding
// the last two bytes from the first call. No request may be reported until the
// second call completes the body.
351 TEST_F(HttpServerTest
, RequestWithBodySplitAcrossPackets
) {
// Register a mock connection with the server. The raw |socket| pointer is
// still used below to identify that connection in the DidRead calls.
352 StreamListenSocket
* socket
=
353 new MockStreamListenSocket(server_
.get());
354 server_
->DidAccept(NULL
, make_scoped_ptr(socket
));
355 std::string
body("body");
356 std::string request_text
= base::StringPrintf(
357 "GET /test HTTP/1.1\r\n"
359 "Content-Length: %" PRIuS
"\r\n\r\n%s",
// First chunk: everything except the final two bytes -- request is incomplete.
362 server_
->DidRead(socket
, request_text
.c_str(), request_text
.length() - 2);
363 ASSERT_EQ(0u, requests_
.size());
// Second chunk: the remaining two bytes complete the request.
364 server_
->DidRead(socket
, request_text
.c_str() + request_text
.length() - 2, 2);
365 ASSERT_EQ(1u, requests_
.size());
366 ASSERT_EQ(body
, GetRequest(0).data
);
// Sends three requests over one client connection, answering each before the
// next is sent, and checks that every request is parsed and reported on the
// same connection id.
369 TEST_F(HttpServerTest
, MultipleRequestsOnSameConnection
) {
370 // The idea behind this test is that requests with or without bodies should
371 // not break parsing of the next request.
372 TestHttpClient client
;
373 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
// Request 1: carries a body.
374 std::string body
= "body";
375 client
.Send(base::StringPrintf(
376 "GET /test HTTP/1.1\r\n"
377 "Content-Length: %" PRIuS
"\r\n\r\n%s",
380 ASSERT_TRUE(RunUntilRequestsReceived(1));
381 ASSERT_EQ(body
, GetRequest(0).data
);
// Remember the connection id; later requests must arrive on the same one.
383 int client_connection_id
= GetConnectionId(0);
384 server_
->Send200(client_connection_id
, "Content for /test", "text/plain");
385 std::string response1
;
386 ASSERT_TRUE(client
.Read(&response1
));
387 ASSERT_TRUE(StartsWithASCII(response1
, "HTTP/1.1 200 OK", true));
388 ASSERT_TRUE(EndsWith(response1
, "Content for /test", true));
// Request 2: no body; answered with a 404.
390 client
.Send("GET /test2 HTTP/1.1\r\n\r\n");
391 ASSERT_TRUE(RunUntilRequestsReceived(2));
392 ASSERT_EQ("/test2", GetRequest(1).path
);
394 ASSERT_EQ(client_connection_id
, GetConnectionId(1));
395 server_
->Send404(client_connection_id
);
396 std::string response2
;
397 ASSERT_TRUE(client
.Read(&response2
));
398 ASSERT_TRUE(StartsWithASCII(response2
, "HTTP/1.1 404 Not Found", true));
// Request 3: no body; answered with a 200.
400 client
.Send("GET /test3 HTTP/1.1\r\n\r\n");
401 ASSERT_TRUE(RunUntilRequestsReceived(3));
402 ASSERT_EQ("/test3", GetRequest(2).path
);
404 ASSERT_EQ(client_connection_id
, GetConnectionId(2));
405 server_
->Send200(client_connection_id
, "Content for /test3", "text/plain");
406 std::string response3
;
407 ASSERT_TRUE(client
.Read(&response3
));
408 ASSERT_TRUE(StartsWithASCII(response3
, "HTTP/1.1 200 OK", true));
409 ASSERT_TRUE(EndsWith(response3
, "Content for /test3", true));