Use EXPECT_EQ when possible.
[chromium-blink-merge.git] / net / socket / socks5_client_socket_unittest.cc
blob5bcc146fc1b24704df45e7f1151a8bc43c5f8a03
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/socks5_client_socket.h"
7 #include <algorithm>
8 #include <iterator>
9 #include <map>
11 #include "base/sys_byteorder.h"
12 #include "net/base/address_list.h"
13 #include "net/base/net_log.h"
14 #include "net/base/net_log_unittest.h"
15 #include "net/base/test_completion_callback.h"
16 #include "net/base/winsock_init.h"
17 #include "net/dns/mock_host_resolver.h"
18 #include "net/socket/client_socket_factory.h"
19 #include "net/socket/socket_test_util.h"
20 #include "net/socket/tcp_client_socket.h"
21 #include "testing/gtest/include/gtest/gtest.h"
22 #include "testing/platform_test.h"
24 //-----------------------------------------------------------------------------
26 namespace net {
28 namespace {
30 // Base class to test SOCKS5ClientSocket
31 class SOCKS5ClientSocketTest : public PlatformTest {
32 public:
33 SOCKS5ClientSocketTest();
34 // Create a SOCKSClientSocket on top of a MockSocket.
35 scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[],
36 size_t reads_count,
37 MockWrite writes[],
38 size_t writes_count,
39 const std::string& hostname,
40 int port,
41 NetLog* net_log);
43 void SetUp() override;
45 protected:
46 const uint16 kNwPort;
47 CapturingNetLog net_log_;
48 scoped_ptr<SOCKS5ClientSocket> user_sock_;
49 AddressList address_list_;
50 // Filled in by BuildMockSocket() and owned by its return value
51 // (which |user_sock| is set to).
52 StreamSocket* tcp_sock_;
53 TestCompletionCallback callback_;
54 scoped_ptr<MockHostResolver> host_resolver_;
55 scoped_ptr<SocketDataProvider> data_;
57 private:
58 DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest);
61 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
62 : kNwPort(base::HostToNet16(80)),
63 host_resolver_(new MockHostResolver) {
66 // Set up platform before every test case
67 void SOCKS5ClientSocketTest::SetUp() {
68 PlatformTest::SetUp();
70 // Resolve the "localhost" AddressList used by the TCP connection to connect.
71 HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080));
72 TestCompletionCallback callback;
73 int rv = host_resolver_->Resolve(info,
74 DEFAULT_PRIORITY,
75 &address_list_,
76 callback.callback(),
77 NULL,
78 BoundNetLog());
79 ASSERT_EQ(ERR_IO_PENDING, rv);
80 rv = callback.WaitForResult();
81 ASSERT_EQ(OK, rv);
84 scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
85 MockRead reads[],
86 size_t reads_count,
87 MockWrite writes[],
88 size_t writes_count,
89 const std::string& hostname,
90 int port,
91 NetLog* net_log) {
92 TestCompletionCallback callback;
93 data_.reset(new StaticSocketDataProvider(reads, reads_count,
94 writes, writes_count));
95 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
97 int rv = tcp_sock_->Connect(callback.callback());
98 EXPECT_EQ(ERR_IO_PENDING, rv);
99 rv = callback.WaitForResult();
100 EXPECT_EQ(OK, rv);
101 EXPECT_TRUE(tcp_sock_->IsConnected());
103 scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
104 // |connection| takes ownership of |tcp_sock_|, but keep a
105 // non-owning pointer to it.
106 connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
107 return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket(
108 connection.Pass(),
109 HostResolver::RequestInfo(HostPortPair(hostname, port))));
112 // Tests a complete SOCKS5 handshake and the disconnection.
113 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
114 const std::string payload_write = "random data";
115 const std::string payload_read = "moar random data";
117 const char kOkRequest[] = {
118 0x05, // Version
119 0x01, // Command (CONNECT)
120 0x00, // Reserved.
121 0x03, // Address type (DOMAINNAME).
122 0x09, // Length of domain (9)
123 // Domain string:
124 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
125 0x00, 0x50, // 16-bit port (80)
128 MockWrite data_writes[] = {
129 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
130 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)),
131 MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
132 MockRead data_reads[] = {
133 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
134 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
135 MockRead(ASYNC, payload_read.data(), payload_read.size()) };
137 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
138 data_writes, arraysize(data_writes),
139 "localhost", 80, &net_log_);
141 // At this state the TCP connection is completed but not the SOCKS handshake.
142 EXPECT_TRUE(tcp_sock_->IsConnected());
143 EXPECT_FALSE(user_sock_->IsConnected());
145 int rv = user_sock_->Connect(callback_.callback());
146 EXPECT_EQ(ERR_IO_PENDING, rv);
147 EXPECT_FALSE(user_sock_->IsConnected());
149 CapturingNetLog::CapturedEntryList net_log_entries;
150 net_log_.GetEntries(&net_log_entries);
151 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
152 NetLog::TYPE_SOCKS5_CONNECT));
154 rv = callback_.WaitForResult();
156 EXPECT_EQ(OK, rv);
157 EXPECT_TRUE(user_sock_->IsConnected());
159 net_log_.GetEntries(&net_log_entries);
160 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
161 NetLog::TYPE_SOCKS5_CONNECT));
163 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
164 memcpy(buffer->data(), payload_write.data(), payload_write.size());
165 rv = user_sock_->Write(
166 buffer.get(), payload_write.size(), callback_.callback());
167 EXPECT_EQ(ERR_IO_PENDING, rv);
168 rv = callback_.WaitForResult();
169 EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
171 buffer = new IOBuffer(payload_read.size());
172 rv =
173 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
174 EXPECT_EQ(ERR_IO_PENDING, rv);
175 rv = callback_.WaitForResult();
176 EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
177 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
179 user_sock_->Disconnect();
180 EXPECT_FALSE(tcp_sock_->IsConnected());
181 EXPECT_FALSE(user_sock_->IsConnected());
184 // Test that you can call Connect() again after having called Disconnect().
185 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
186 const std::string hostname = "my-host-name";
187 const char kSOCKS5DomainRequest[] = {
188 0x05, // VER
189 0x01, // CMD
190 0x00, // RSV
191 0x03, // ATYPE
194 std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest));
195 request.push_back(hostname.size());
196 request.append(hostname);
197 request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
199 for (int i = 0; i < 2; ++i) {
200 MockWrite data_writes[] = {
201 MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
202 MockWrite(SYNCHRONOUS, request.data(), request.size())
204 MockRead data_reads[] = {
205 MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
206 MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
209 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
210 data_writes, arraysize(data_writes),
211 hostname, 80, NULL);
213 int rv = user_sock_->Connect(callback_.callback());
214 EXPECT_EQ(OK, rv);
215 EXPECT_TRUE(user_sock_->IsConnected());
217 user_sock_->Disconnect();
218 EXPECT_FALSE(user_sock_->IsConnected());
222 // Test that we fail trying to connect to a hosname longer than 255 bytes.
223 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
224 // Create a string of length 256, where each character is 'x'.
225 std::string large_host_name;
226 std::fill_n(std::back_inserter(large_host_name), 256, 'x');
228 // Create a SOCKS socket, with mock transport socket.
229 MockWrite data_writes[] = {MockWrite()};
230 MockRead data_reads[] = {MockRead()};
231 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
232 data_writes, arraysize(data_writes),
233 large_host_name, 80, NULL);
235 // Try to connect -- should fail (without having read/written anything to
236 // the transport socket first) because the hostname is too long.
237 TestCompletionCallback callback;
238 int rv = user_sock_->Connect(callback.callback());
239 EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
242 TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
243 const std::string hostname = "www.google.com";
245 const char kOkRequest[] = {
246 0x05, // Version
247 0x01, // Command (CONNECT)
248 0x00, // Reserved.
249 0x03, // Address type (DOMAINNAME).
250 0x0E, // Length of domain (14)
251 // Domain string:
252 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
253 0x00, 0x50, // 16-bit port (80)
256 // Test for partial greet request write
258 const char partial1[] = { 0x05, 0x01 };
259 const char partial2[] = { 0x00 };
260 MockWrite data_writes[] = {
261 MockWrite(ASYNC, partial1, arraysize(partial1)),
262 MockWrite(ASYNC, partial2, arraysize(partial2)),
263 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) };
264 MockRead data_reads[] = {
265 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
266 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
267 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
268 data_writes, arraysize(data_writes),
269 hostname, 80, &net_log_);
270 int rv = user_sock_->Connect(callback_.callback());
271 EXPECT_EQ(ERR_IO_PENDING, rv);
273 CapturingNetLog::CapturedEntryList net_log_entries;
274 net_log_.GetEntries(&net_log_entries);
275 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
276 NetLog::TYPE_SOCKS5_CONNECT));
278 rv = callback_.WaitForResult();
279 EXPECT_EQ(OK, rv);
280 EXPECT_TRUE(user_sock_->IsConnected());
282 net_log_.GetEntries(&net_log_entries);
283 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
284 NetLog::TYPE_SOCKS5_CONNECT));
287 // Test for partial greet response read
289 const char partial1[] = { 0x05 };
290 const char partial2[] = { 0x00 };
291 MockWrite data_writes[] = {
292 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
293 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) };
294 MockRead data_reads[] = {
295 MockRead(ASYNC, partial1, arraysize(partial1)),
296 MockRead(ASYNC, partial2, arraysize(partial2)),
297 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
298 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
299 data_writes, arraysize(data_writes),
300 hostname, 80, &net_log_);
301 int rv = user_sock_->Connect(callback_.callback());
302 EXPECT_EQ(ERR_IO_PENDING, rv);
304 CapturingNetLog::CapturedEntryList net_log_entries;
305 net_log_.GetEntries(&net_log_entries);
306 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
307 NetLog::TYPE_SOCKS5_CONNECT));
308 rv = callback_.WaitForResult();
309 EXPECT_EQ(OK, rv);
310 EXPECT_TRUE(user_sock_->IsConnected());
311 net_log_.GetEntries(&net_log_entries);
312 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
313 NetLog::TYPE_SOCKS5_CONNECT));
316 // Test for partial handshake request write.
318 const int kSplitPoint = 3; // Break handshake write into two parts.
319 MockWrite data_writes[] = {
320 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
321 MockWrite(ASYNC, kOkRequest, kSplitPoint),
322 MockWrite(ASYNC, kOkRequest + kSplitPoint,
323 arraysize(kOkRequest) - kSplitPoint)
325 MockRead data_reads[] = {
326 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
327 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
328 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
329 data_writes, arraysize(data_writes),
330 hostname, 80, &net_log_);
331 int rv = user_sock_->Connect(callback_.callback());
332 EXPECT_EQ(ERR_IO_PENDING, rv);
333 CapturingNetLog::CapturedEntryList net_log_entries;
334 net_log_.GetEntries(&net_log_entries);
335 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
336 NetLog::TYPE_SOCKS5_CONNECT));
337 rv = callback_.WaitForResult();
338 EXPECT_EQ(OK, rv);
339 EXPECT_TRUE(user_sock_->IsConnected());
340 net_log_.GetEntries(&net_log_entries);
341 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
342 NetLog::TYPE_SOCKS5_CONNECT));
345 // Test for partial handshake response read
347 const int kSplitPoint = 6; // Break the handshake read into two parts.
348 MockWrite data_writes[] = {
349 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
350 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest))
352 MockRead data_reads[] = {
353 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
354 MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
355 MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
356 kSOCKS5OkResponseLength - kSplitPoint)
359 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
360 data_writes, arraysize(data_writes),
361 hostname, 80, &net_log_);
362 int rv = user_sock_->Connect(callback_.callback());
363 EXPECT_EQ(ERR_IO_PENDING, rv);
364 CapturingNetLog::CapturedEntryList net_log_entries;
365 net_log_.GetEntries(&net_log_entries);
366 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
367 NetLog::TYPE_SOCKS5_CONNECT));
368 rv = callback_.WaitForResult();
369 EXPECT_EQ(OK, rv);
370 EXPECT_TRUE(user_sock_->IsConnected());
371 net_log_.GetEntries(&net_log_entries);
372 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
373 NetLog::TYPE_SOCKS5_CONNECT));
377 } // namespace
379 } // namespace net