Revert 268405 "Make sure that ScratchBuffer::Allocate() always r..."
[chromium-blink-merge.git] / net / socket / socks_client_socket_unittest.cc
blobf361244feff01a7dfa3138a7bc4ded1f36b13e29
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/socks_client_socket.h"
7 #include "base/memory/scoped_ptr.h"
8 #include "net/base/address_list.h"
9 #include "net/base/net_log.h"
10 #include "net/base/net_log_unittest.h"
11 #include "net/base/test_completion_callback.h"
12 #include "net/base/winsock_init.h"
13 #include "net/dns/mock_host_resolver.h"
14 #include "net/socket/client_socket_factory.h"
15 #include "net/socket/socket_test_util.h"
16 #include "net/socket/tcp_client_socket.h"
17 #include "testing/gtest/include/gtest/gtest.h"
18 #include "testing/platform_test.h"
20 //-----------------------------------------------------------------------------
22 namespace net {
// SOCKS4 CONNECT request expected from the client when connecting to
// "localhost":80 (resolved to 127.0.0.1), laid out per the SOCKS4 protocol.
const char kSOCKSOkRequest[] = {
  0x04, 0x01,    // protocol version 4, command 1 (CONNECT)
  0x00, 0x50,    // destination port 80, big-endian
  127, 0, 0, 1,  // destination IPv4 address
  0              // empty user ID, NUL-terminated
};

// SOCKS4 reply the tests treat as a successful handshake.
const char kSOCKSOkReply[] = {
  0x00, 0x5A,    // null byte, status 0x5A (request granted)
  0x00, 0x00,    // port field
  0, 0, 0, 0     // address field
};
27 class SOCKSClientSocketTest : public PlatformTest {
28 public:
29 SOCKSClientSocketTest();
30 // Create a SOCKSClientSocket on top of a MockSocket.
31 scoped_ptr<SOCKSClientSocket> BuildMockSocket(
32 MockRead reads[], size_t reads_count,
33 MockWrite writes[], size_t writes_count,
34 HostResolver* host_resolver,
35 const std::string& hostname, int port,
36 NetLog* net_log);
37 virtual void SetUp();
39 protected:
40 scoped_ptr<SOCKSClientSocket> user_sock_;
41 AddressList address_list_;
42 // Filled in by BuildMockSocket() and owned by its return value
43 // (which |user_sock| is set to).
44 StreamSocket* tcp_sock_;
45 TestCompletionCallback callback_;
46 scoped_ptr<MockHostResolver> host_resolver_;
47 scoped_ptr<SocketDataProvider> data_;
50 SOCKSClientSocketTest::SOCKSClientSocketTest()
51 : host_resolver_(new MockHostResolver) {
54 // Set up platform before every test case
55 void SOCKSClientSocketTest::SetUp() {
56 PlatformTest::SetUp();
59 scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
60 MockRead reads[],
61 size_t reads_count,
62 MockWrite writes[],
63 size_t writes_count,
64 HostResolver* host_resolver,
65 const std::string& hostname,
66 int port,
67 NetLog* net_log) {
69 TestCompletionCallback callback;
70 data_.reset(new StaticSocketDataProvider(reads, reads_count,
71 writes, writes_count));
72 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
74 int rv = tcp_sock_->Connect(callback.callback());
75 EXPECT_EQ(ERR_IO_PENDING, rv);
76 rv = callback.WaitForResult();
77 EXPECT_EQ(OK, rv);
78 EXPECT_TRUE(tcp_sock_->IsConnected());
80 scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
81 // |connection| takes ownership of |tcp_sock_|, but keep a
82 // non-owning pointer to it.
83 connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
84 return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
85 connection.Pass(),
86 HostResolver::RequestInfo(HostPortPair(hostname, port)),
87 DEFAULT_PRIORITY,
88 host_resolver));
91 // Implementation of HostResolver that never completes its resolve request.
92 // We use this in the test "DisconnectWhileHostResolveInProgress" to make
93 // sure that the outstanding resolve request gets cancelled.
94 class HangingHostResolverWithCancel : public HostResolver {
95 public:
96 HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
98 virtual int Resolve(const RequestInfo& info,
99 RequestPriority priority,
100 AddressList* addresses,
101 const CompletionCallback& callback,
102 RequestHandle* out_req,
103 const BoundNetLog& net_log) OVERRIDE {
104 DCHECK(addresses);
105 DCHECK_EQ(false, callback.is_null());
106 EXPECT_FALSE(HasOutstandingRequest());
107 outstanding_request_ = reinterpret_cast<RequestHandle>(1);
108 *out_req = outstanding_request_;
109 return ERR_IO_PENDING;
112 virtual int ResolveFromCache(const RequestInfo& info,
113 AddressList* addresses,
114 const BoundNetLog& net_log) OVERRIDE {
115 NOTIMPLEMENTED();
116 return ERR_UNEXPECTED;
119 virtual void CancelRequest(RequestHandle req) OVERRIDE {
120 EXPECT_TRUE(HasOutstandingRequest());
121 EXPECT_EQ(outstanding_request_, req);
122 outstanding_request_ = NULL;
125 bool HasOutstandingRequest() {
126 return outstanding_request_ != NULL;
129 private:
130 RequestHandle outstanding_request_;
132 DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel);
135 // Tests a complete handshake and the disconnection.
136 TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
137 const std::string payload_write = "random data";
138 const std::string payload_read = "moar random data";
140 MockWrite data_writes[] = {
141 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
142 MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
143 MockRead data_reads[] = {
144 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
145 MockRead(ASYNC, payload_read.data(), payload_read.size()) };
146 CapturingNetLog log;
148 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
149 data_writes, arraysize(data_writes),
150 host_resolver_.get(),
151 "localhost", 80,
152 &log);
154 // At this state the TCP connection is completed but not the SOCKS handshake.
155 EXPECT_TRUE(tcp_sock_->IsConnected());
156 EXPECT_FALSE(user_sock_->IsConnected());
158 int rv = user_sock_->Connect(callback_.callback());
159 EXPECT_EQ(ERR_IO_PENDING, rv);
161 CapturingNetLog::CapturedEntryList entries;
162 log.GetEntries(&entries);
163 EXPECT_TRUE(
164 LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
165 EXPECT_FALSE(user_sock_->IsConnected());
167 rv = callback_.WaitForResult();
168 EXPECT_EQ(OK, rv);
169 EXPECT_TRUE(user_sock_->IsConnected());
170 log.GetEntries(&entries);
171 EXPECT_TRUE(LogContainsEndEvent(
172 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
174 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
175 memcpy(buffer->data(), payload_write.data(), payload_write.size());
176 rv = user_sock_->Write(
177 buffer.get(), payload_write.size(), callback_.callback());
178 EXPECT_EQ(ERR_IO_PENDING, rv);
179 rv = callback_.WaitForResult();
180 EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
182 buffer = new IOBuffer(payload_read.size());
183 rv =
184 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
185 EXPECT_EQ(ERR_IO_PENDING, rv);
186 rv = callback_.WaitForResult();
187 EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
188 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
190 user_sock_->Disconnect();
191 EXPECT_FALSE(tcp_sock_->IsConnected());
192 EXPECT_FALSE(user_sock_->IsConnected());
195 // List of responses from the socks server and the errors they should
196 // throw up are tested here.
197 TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
198 const struct {
199 const char fail_reply[8];
200 Error fail_code;
201 } tests[] = {
202 // Failure of the server response code
204 { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
205 ERR_SOCKS_CONNECTION_FAILED,
207 // Failure of the null byte
209 { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
210 ERR_SOCKS_CONNECTION_FAILED,
214 //---------------------------------------
216 for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
217 MockWrite data_writes[] = {
218 MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
219 MockRead data_reads[] = {
220 MockRead(SYNCHRONOUS, tests[i].fail_reply,
221 arraysize(tests[i].fail_reply)) };
222 CapturingNetLog log;
224 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
225 data_writes, arraysize(data_writes),
226 host_resolver_.get(),
227 "localhost", 80,
228 &log);
230 int rv = user_sock_->Connect(callback_.callback());
231 EXPECT_EQ(ERR_IO_PENDING, rv);
233 CapturingNetLog::CapturedEntryList entries;
234 log.GetEntries(&entries);
235 EXPECT_TRUE(LogContainsBeginEvent(
236 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
238 rv = callback_.WaitForResult();
239 EXPECT_EQ(tests[i].fail_code, rv);
240 EXPECT_FALSE(user_sock_->IsConnected());
241 EXPECT_TRUE(tcp_sock_->IsConnected());
242 log.GetEntries(&entries);
243 EXPECT_TRUE(LogContainsEndEvent(
244 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
248 // Tests scenario when the server sends the handshake response in
249 // more than one packet.
250 TEST_F(SOCKSClientSocketTest, PartialServerReads) {
251 const char kSOCKSPartialReply1[] = { 0x00 };
252 const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
254 MockWrite data_writes[] = {
255 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
256 MockRead data_reads[] = {
257 MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
258 MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
259 CapturingNetLog log;
261 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
262 data_writes, arraysize(data_writes),
263 host_resolver_.get(),
264 "localhost", 80,
265 &log);
267 int rv = user_sock_->Connect(callback_.callback());
268 EXPECT_EQ(ERR_IO_PENDING, rv);
269 CapturingNetLog::CapturedEntryList entries;
270 log.GetEntries(&entries);
271 EXPECT_TRUE(LogContainsBeginEvent(
272 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
274 rv = callback_.WaitForResult();
275 EXPECT_EQ(OK, rv);
276 EXPECT_TRUE(user_sock_->IsConnected());
277 log.GetEntries(&entries);
278 EXPECT_TRUE(LogContainsEndEvent(
279 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
282 // Tests scenario when the client sends the handshake request in
283 // more than one packet.
284 TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
285 const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
286 const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
288 MockWrite data_writes[] = {
289 MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)),
290 // simulate some empty writes
291 MockWrite(ASYNC, 0),
292 MockWrite(ASYNC, 0),
293 MockWrite(ASYNC, kSOCKSPartialRequest2,
294 arraysize(kSOCKSPartialRequest2)) };
295 MockRead data_reads[] = {
296 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
297 CapturingNetLog log;
299 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
300 data_writes, arraysize(data_writes),
301 host_resolver_.get(),
302 "localhost", 80,
303 &log);
305 int rv = user_sock_->Connect(callback_.callback());
306 EXPECT_EQ(ERR_IO_PENDING, rv);
307 CapturingNetLog::CapturedEntryList entries;
308 log.GetEntries(&entries);
309 EXPECT_TRUE(LogContainsBeginEvent(
310 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
312 rv = callback_.WaitForResult();
313 EXPECT_EQ(OK, rv);
314 EXPECT_TRUE(user_sock_->IsConnected());
315 log.GetEntries(&entries);
316 EXPECT_TRUE(LogContainsEndEvent(
317 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
320 // Tests the case when the server sends a smaller sized handshake data
321 // and closes the connection.
322 TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
323 MockWrite data_writes[] = {
324 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
325 MockRead data_reads[] = {
326 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
327 // close connection unexpectedly
328 MockRead(SYNCHRONOUS, 0) };
329 CapturingNetLog log;
331 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
332 data_writes, arraysize(data_writes),
333 host_resolver_.get(),
334 "localhost", 80,
335 &log);
337 int rv = user_sock_->Connect(callback_.callback());
338 EXPECT_EQ(ERR_IO_PENDING, rv);
339 CapturingNetLog::CapturedEntryList entries;
340 log.GetEntries(&entries);
341 EXPECT_TRUE(LogContainsBeginEvent(
342 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
344 rv = callback_.WaitForResult();
345 EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
346 EXPECT_FALSE(user_sock_->IsConnected());
347 log.GetEntries(&entries);
348 EXPECT_TRUE(LogContainsEndEvent(
349 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
352 // Tries to connect to an unknown hostname. Should fail rather than
353 // falling back to SOCKS4a.
354 TEST_F(SOCKSClientSocketTest, FailedDNS) {
355 const char hostname[] = "unresolved.ipv4.address";
357 host_resolver_->rules()->AddSimulatedFailure(hostname);
359 CapturingNetLog log;
361 user_sock_ = BuildMockSocket(NULL, 0,
362 NULL, 0,
363 host_resolver_.get(),
364 hostname, 80,
365 &log);
367 int rv = user_sock_->Connect(callback_.callback());
368 EXPECT_EQ(ERR_IO_PENDING, rv);
369 CapturingNetLog::CapturedEntryList entries;
370 log.GetEntries(&entries);
371 EXPECT_TRUE(LogContainsBeginEvent(
372 entries, 0, NetLog::TYPE_SOCKS_CONNECT));
374 rv = callback_.WaitForResult();
375 EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
376 EXPECT_FALSE(user_sock_->IsConnected());
377 log.GetEntries(&entries);
378 EXPECT_TRUE(LogContainsEndEvent(
379 entries, -1, NetLog::TYPE_SOCKS_CONNECT));
382 // Calls Disconnect() while a host resolve is in progress. The outstanding host
383 // resolve should be cancelled.
384 TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
385 scoped_ptr<HangingHostResolverWithCancel> hanging_resolver(
386 new HangingHostResolverWithCancel());
388 // Doesn't matter what the socket data is, we will never use it -- garbage.
389 MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
390 MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
392 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
393 data_writes, arraysize(data_writes),
394 hanging_resolver.get(),
395 "foo", 80,
396 NULL);
398 // Start connecting (will get stuck waiting for the host to resolve).
399 int rv = user_sock_->Connect(callback_.callback());
400 EXPECT_EQ(ERR_IO_PENDING, rv);
402 EXPECT_FALSE(user_sock_->IsConnected());
403 EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
405 // The host resolver should have received the resolve request.
406 EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
408 // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
409 user_sock_->Disconnect();
411 EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
413 EXPECT_FALSE(user_sock_->IsConnected());
414 EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
417 } // namespace net