1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
8 #include "base/bind_helpers.h"
9 #include "base/compiler_specific.h"
10 #include "base/format_macros.h"
11 #include "base/memory/ref_counted.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/memory/weak_ptr.h"
14 #include "base/message_loop/message_loop.h"
15 #include "base/message_loop/message_loop_proxy.h"
16 #include "base/run_loop.h"
17 #include "base/strings/string_split.h"
18 #include "base/strings/string_util.h"
19 #include "base/strings/stringprintf.h"
20 #include "base/time/time.h"
21 #include "net/base/address_list.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/ip_endpoint.h"
24 #include "net/base/net_errors.h"
25 #include "net/base/net_log.h"
26 #include "net/server/http_server.h"
27 #include "net/server/http_server_request_info.h"
28 #include "net/socket/tcp_client_socket.h"
29 #include "net/socket/tcp_listen_socket.h"
30 #include "net/url_request/url_fetcher.h"
31 #include "net/url_request/url_fetcher_delegate.h"
32 #include "net/url_request/url_request_context.h"
33 #include "net/url_request/url_request_context_getter.h"
34 #include "net/url_request/url_request_test_util.h"
35 #include "testing/gtest/include/gtest/gtest.h"
// Callback that RunLoopWithTimeout() posts as a delayed task. It receives a
// weak pointer to the caller's timed-out flag and the closure that quits the
// caller's run loop.
// NOTE(review): the function body is not visible here; presumably it sets the
// flag and runs |quit_loop_func| only while the weak pointer is still valid --
// confirm against the full file.
41 void SetTimedOutAndQuitLoop(const base::WeakPtr
<bool> timed_out
,
42 const base::Closure
& quit_loop_func
) {
// Arms a one-second timeout for |run_loop|: posts SetTimedOutAndQuitLoop as a
// delayed task on the current MessageLoop, bound to a weak pointer to the
// local |timed_out| flag and to |run_loop|'s quit closure.
// NOTE(review): the statements that run the loop and produce the bool result
// are not visible here; callers treat a true result as "did not time out" --
// confirm.
49 bool RunLoopWithTimeout(base::RunLoop
* run_loop
) {
50 bool timed_out
= false;
// The flag is handed to the delayed task through a weak pointer rather than a
// raw pointer; presumably so that a task firing after this frame has returned
// cannot write to the dead stack variable -- confirm.
51 base::WeakPtrFactory
<bool> timed_out_weak_factory(&timed_out
);
52 base::MessageLoop::current()->PostDelayedTask(
54 base::Bind(&SetTimedOutAndQuitLoop
,
55 timed_out_weak_factory
.GetWeakPtr(),
56 run_loop
->QuitClosure()),
57 base::TimeDelta::FromSeconds(1));
// Small TCP client used by the tests below: it connects a TCPClientSocket to
// an endpoint and writes raw request bytes to it.
// NOTE(review): several statements of this class (access specifiers, some
// method bodies, the declaration of |connect_result_|) are not visible here.
62 class TestHttpClient
{
// Starts with |connect_result_| set to OK.
64 TestHttpClient() : connect_result_(OK
) {}
// Creates a new socket for |address| and starts connecting, with OnConnect
// bound as the completion callback. Returns the result immediately when
// Connect() reports anything other than OK or ERR_IO_PENDING; otherwise waits
// via RunLoopWithTimeout() and returns |connect_result_|.
66 int ConnectAndWait(const IPEndPoint
& address
) {
67 AddressList
addresses(address
);
68 NetLog::Source source
;
69 socket_
.reset(new TCPClientSocket(addresses
, NULL
, source
));
71 base::RunLoop run_loop
;
72 connect_result_
= socket_
->Connect(base::Bind(&TestHttpClient::OnConnect
,
73 base::Unretained(this),
74 run_loop
.QuitClosure()));
75 if (connect_result_
!= OK
&& connect_result_
!= ERR_IO_PENDING
)
76 return connect_result_
;
// NOTE(review): the statement executed when the wait fails is not visible
// here; presumably a timeout error is returned -- confirm.
78 if (!RunLoopWithTimeout(&run_loop
))
80 return connect_result_
;
// Wraps |data| in a StringIOBuffer inside a DrainableIOBuffer sized to
// data.length().
// NOTE(review): the assignment target and the call that starts the write are
// not visible here; presumably the buffer is stored in |write_buffer_|.
83 void Send(const std::string
& data
) {
85 new DrainableIOBuffer(new StringIOBuffer(data
), data
.length());
// Connect completion callback: records |result| in |connect_result_|.
// NOTE(review): the use of |quit_loop| is not visible here; presumably it is
// run to wake ConnectAndWait().
90 void OnConnect(const base::Closure
& quit_loop
, int result
) {
91 connect_result_
= result
;
// Write step (enclosing method header not visible here): asks the socket to
// write write_buffer_->BytesRemaining() bytes, with OnWrite as the completion
// callback.
// NOTE(review): the action taken when the result is not ERR_IO_PENDING is not
// visible; presumably OnWrite is invoked directly with the result.
96 int result
= socket_
->Write(
98 write_buffer_
->BytesRemaining(),
99 base::Bind(&TestHttpClient::OnWrite
, base::Unretained(this)));
100 if (result
!= ERR_IO_PENDING
)
// Write completion: asserts that |result| is positive and marks that many
// bytes of |write_buffer_| as consumed.
// NOTE(review): the action taken while bytes remain is not visible here;
// presumably another write is issued.
104 void OnWrite(int result
) {
105 ASSERT_GT(result
, 0);
106 write_buffer_
->DidConsume(result
);
107 if (write_buffer_
->BytesRemaining())
// Buffer holding the bytes still to be written to the socket.
111 scoped_refptr
<DrainableIOBuffer
> write_buffer_
;
// Client socket; replaced on every ConnectAndWait() call.
112 scoped_ptr
<TCPClientSocket
> socket_
;
// Test fixture that owns an HttpServer and also acts as its delegate, recording
// every HTTP request the server reports in |requests_|.
// NOTE(review): access specifiers and several method bodies are not visible
// here.
118 class HttpServerTest
: public testing::Test
,
119 public HttpServer::Delegate
{
121 HttpServerTest() : quit_after_request_count_(0) {}
// Creates the server on a listen socket factory for "127.0.0.1" with port 0,
// with this fixture as delegate, and asserts that the server's local address
// can be read into |server_address_|.
123 virtual void SetUp() OVERRIDE
{
124 TCPListenSocketFactory
socket_factory("127.0.0.1", 0);
125 server_
= new HttpServer(socket_factory
, this);
126 ASSERT_EQ(OK
, server_
->GetLocalAddress(&server_address_
));
// HttpServer::Delegate: stores |info| and, once the number of stored requests
// equals |quit_after_request_count_|, runs |run_loop_quit_func_|.
129 virtual void OnHttpRequest(int connection_id
,
130 const HttpServerRequestInfo
& info
) OVERRIDE
{
131 requests_
.push_back(info
);
132 if (requests_
.size() == quit_after_request_count_
)
133 run_loop_quit_func_
.Run();
// HttpServer::Delegate override. NOTE(review): body not visible here.
136 virtual void OnWebSocketRequest(int connection_id
,
137 const HttpServerRequestInfo
& info
) OVERRIDE
{
// HttpServer::Delegate override. NOTE(review): body not visible here.
141 virtual void OnWebSocketMessage(int connection_id
,
142 const std::string
& data
) OVERRIDE
{
// HttpServer::Delegate override; intentionally does nothing.
146 virtual void OnClose(int connection_id
) OVERRIDE
{}
// Sets the request-count threshold to |count| and, unless that many requests
// are already recorded, spins a RunLoop under RunLoopWithTimeout() whose quit
// closure OnHttpRequest() fires; the closure is reset afterwards.
// NOTE(review): the early-return and final return statements are not visible
// here; presumably |success| is returned.
148 bool RunUntilRequestsReceived(size_t count
) {
149 quit_after_request_count_
= count
;
150 if (requests_
.size() == count
)
153 base::RunLoop run_loop
;
154 run_loop_quit_func_
= run_loop
.QuitClosure();
155 bool success
= RunLoopWithTimeout(&run_loop
);
156 run_loop_quit_func_
.Reset();
// Server under test.
161 scoped_refptr
<HttpServer
> server_
;
// Address filled in by SetUp() from HttpServer::GetLocalAddress().
162 IPEndPoint server_address_
;
// Quit closure of the RunLoop currently waiting in RunUntilRequestsReceived().
163 base::Closure run_loop_quit_func_
;
// Requests received so far, in arrival order.
164 std::vector
<HttpServerRequestInfo
> requests_
;
// Number of recorded requests at which OnHttpRequest() quits the run loop.
167 size_t quit_after_request_count_
;
// Sends a bare "GET /test" request with no headers or body and checks the
// method, path, empty data and empty header map of the recorded request.
170 TEST_F(HttpServerTest
, Request
) {
171 TestHttpClient client
;
172 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
173 client
.Send("GET /test HTTP/1.1\r\n\r\n");
174 ASSERT_TRUE(RunUntilRequestsReceived(1));
175 ASSERT_EQ("GET", requests_
[0].method
);
176 ASSERT_EQ("/test", requests_
[0].path
);
177 ASSERT_EQ("", requests_
[0].data
);
178 ASSERT_EQ(0u, requests_
[0].headers
.size());
// Sends one request carrying a table of differently formatted header lines and
// checks that each header is reported once, under its lower-cased name, with
// the value from the table.
181 TEST_F(HttpServerTest
, RequestWithHeaders
) {
182 TestHttpClient client
;
183 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
// Each row is {field name, separator text, field value}.
184 const char* kHeaders
[][3] = {
185 {"Header", ": ", "1"},
186 {"HeaderWithNoWhitespace", ":", "1"},
187 {"HeaderWithWhitespace", " : \t ", "1 1 1 \t "},
188 {"HeaderWithColon", ": ", "1:1"},
189 {"EmptyHeader", ":", ""},
190 {"EmptyHeaderWithWhitespace", ": \t ", ""},
191 {"HeaderWithNonASCII", ": ", "\u00f7"},
// Builds one "name + separator + value + CRLF" line per table row.
// NOTE(review): the declaration of |headers| and the left-hand side of this
// expression are not visible here; presumably each line is appended to it.
194 for (size_t i
= 0; i
< arraysize(kHeaders
); ++i
) {
196 std::string(kHeaders
[i
][0]) + kHeaders
[i
][1] + kHeaders
[i
][2] + "\r\n";
199 client
.Send("GET /test HTTP/1.1\r\n" + headers
+ "\r\n");
200 ASSERT_TRUE(RunUntilRequestsReceived(1));
201 ASSERT_EQ("", requests_
[0].data
);
// Header names are looked up by their ASCII-lower-cased form; the streamed
// text after each assertion names the header that failed.
203 for (size_t i
= 0; i
< arraysize(kHeaders
); ++i
) {
204 std::string field
= StringToLowerASCII(std::string(kHeaders
[i
][0]));
205 std::string value
= kHeaders
[i
][2];
206 ASSERT_EQ(1u, requests_
[0].headers
.count(field
)) << field
;
207 ASSERT_EQ(value
, requests_
[0].headers
[field
]) << kHeaders
[i
][0];
// Sends a request with a 1026-byte body ('a', 1024 x 'b', 'c') announced via
// Content-Length and checks the header count and the received data length.
211 TEST_F(HttpServerTest
, RequestWithBody
) {
212 TestHttpClient client
;
213 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
214 std::string body
= "a" + std::string(1 << 10, 'b') + "c";
// NOTE(review): part of the format string and the StringPrintf arguments are
// not visible here.
215 client
.Send(base::StringPrintf(
216 "GET /test HTTP/1.1\r\n"
218 "Content-Length: %" PRIuS
"\r\n\r\n%s",
221 ASSERT_TRUE(RunUntilRequestsReceived(1));
222 ASSERT_EQ(2u, requests_
[0].headers
.size());
223 ASSERT_EQ(body
.length(), requests_
[0].data
.length());
// NOTE(review): the next two assertions inspect the local |body|, not
// requests_[0].data, so they cannot fail; they look intended to verify the
// first and last bytes of the received data -- confirm and fix.
224 ASSERT_EQ('a', body
[0]);
225 ASSERT_EQ('c', *body
.rbegin());
// Issues a URLFetcher request that advertises a content-length of 1 << 30
// bytes and expects the server to answer HTTP_INTERNAL_SERVER_ERROR without
// reporting any request to the delegate.
228 TEST_F(HttpServerTest
, RequestWithTooLargeBody
) {
// Fetch delegate that checks the response code and then quits the run loop.
229 class TestURLFetcherDelegate
: public URLFetcherDelegate
{
231 TestURLFetcherDelegate(const base::Closure
& quit_loop_func
)
232 : quit_loop_func_(quit_loop_func
) {}
233 virtual ~TestURLFetcherDelegate() {}
235 virtual void OnURLFetchComplete(const URLFetcher
* source
) OVERRIDE
{
236 EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR
, source
->GetResponseCode());
237 quit_loop_func_
.Run();
241 base::Closure quit_loop_func_
;
244 base::RunLoop run_loop
;
245 TestURLFetcherDelegate
delegate(run_loop
.QuitClosure());
247 scoped_refptr
<URLRequestContextGetter
> request_context_getter(
248 new TestURLRequestContextGetter(base::MessageLoopProxy::current()));
// The fetch targets /test on the port the test server reports.
// NOTE(review): the remaining URLFetcher::Create arguments and the call that
// starts the fetch are not visible here.
249 scoped_ptr
<URLFetcher
> fetcher(
250 URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test",
251 server_address_
.port())),
254 fetcher
->SetRequestContext(request_context_getter
.get());
255 fetcher
->AddExtraRequestHeader(
256 base::StringPrintf("content-length:%d", 1 << 30));
// The fetch must complete within the timeout and no request may be recorded.
259 ASSERT_TRUE(RunLoopWithTimeout(&run_loop
));
260 ASSERT_EQ(0u, requests_
.size());
// StreamListenSocket stand-in built on kInvalidSocket, so tests can hand the
// server a connection object without a real socket. Accept() must never be
// called on it (NOTREACHED).
// NOTE(review): access specifiers and any further members are not visible
// here.
265 class MockStreamListenSocket
: public StreamListenSocket
{
267 MockStreamListenSocket(StreamListenSocket::Delegate
* delegate
)
268 : StreamListenSocket(kInvalidSocket
, delegate
) {}
270 virtual void Accept() OVERRIDE
{ NOTREACHED(); }
273 virtual ~MockStreamListenSocket() {}
// Feeds a request to the server in two DidRead() calls, with the last two
// bytes arriving separately, and checks that the request is reported only
// after the second call and that its data equals the body.
278 TEST_F(HttpServerTest
, RequestWithBodySplitAcrossPackets
) {
// A mock socket is registered directly with the server via DidAccept(), so no
// real network connection is involved.
279 scoped_refptr
<StreamListenSocket
> socket(
280 new MockStreamListenSocket(server_
.get()));
281 server_
->DidAccept(NULL
, socket
.get());
282 std::string
body("body");
// NOTE(review): part of the format string and the StringPrintf arguments are
// not visible here.
283 std::string request
= base::StringPrintf(
284 "GET /test HTTP/1.1\r\n"
286 "Content-Length: %" PRIuS
"\r\n\r\n%s",
// First chunk: everything except the final two bytes; no request yet.
289 server_
->DidRead(socket
.get(), request
.c_str(), request
.length() - 2);
290 ASSERT_EQ(0u, requests_
.size());
// Second chunk: the final two bytes complete the request.
291 server_
->DidRead(socket
.get(), request
.c_str() + request
.length() - 2, 2);
292 ASSERT_EQ(1u, requests_
.size());
293 ASSERT_EQ(body
, requests_
[0].data
);
// Sends three requests over one client connection -- the first with a body,
// the next two without -- and checks each one as it arrives.
296 TEST_F(HttpServerTest
, MultipleRequestsOnSameConnection
) {
297 // The idea behind this test is that requests with or without bodies should
298 // not break parsing of the next request.
299 TestHttpClient client
;
300 ASSERT_EQ(OK
, client
.ConnectAndWait(server_address_
));
301 std::string body
= "body";
// NOTE(review): the StringPrintf arguments are not visible here.
302 client
.Send(base::StringPrintf(
303 "GET /test HTTP/1.1\r\n"
304 "Content-Length: %" PRIuS
"\r\n\r\n%s",
307 ASSERT_TRUE(RunUntilRequestsReceived(1));
308 ASSERT_EQ(body
, requests_
[0].data
);
// The following requests reuse the same connection and carry no body.
310 client
.Send("GET /test2 HTTP/1.1\r\n\r\n");
311 ASSERT_TRUE(RunUntilRequestsReceived(2));
312 ASSERT_EQ("/test2", requests_
[1].path
);
314 client
.Send("GET /test3 HTTP/1.1\r\n\r\n");
315 ASSERT_TRUE(RunUntilRequestsReceived(3));
316 ASSERT_EQ("/test3", requests_
[2].path
);