1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/socks5_client_socket.h"
11 #include "base/sys_byteorder.h"
12 #include "net/base/address_list.h"
13 #include "net/base/net_log.h"
14 #include "net/base/net_log_unittest.h"
15 #include "net/base/test_completion_callback.h"
16 #include "net/base/winsock_init.h"
17 #include "net/dns/mock_host_resolver.h"
18 #include "net/socket/client_socket_factory.h"
19 #include "net/socket/socket_test_util.h"
20 #include "net/socket/tcp_client_socket.h"
21 #include "testing/gtest/include/gtest/gtest.h"
22 #include "testing/platform_test.h"
24 //-----------------------------------------------------------------------------
30 // Test fixture for SOCKS5ClientSocket. It owns the scripted mock transport
// data, a capturing NetLog and the SOCKS5ClientSocket under test.
// NOTE(review): this listing is missing source lines (the embedded line
// numbers skip, e.g. 32, 36-38, 40-42, 44-46, 56-57); restore them from the
// upstream file before building.
31 class SOCKS5ClientSocketTest
: public PlatformTest
{
33 SOCKS5ClientSocketTest();
34 // Creates a SOCKS5ClientSocket on top of a connected MockTCPClientSocket
// whose traffic is scripted by |reads| and |writes|.
35 scoped_ptr
<SOCKS5ClientSocket
> BuildMockSocket(MockRead reads
[],
39 const std::string
& hostname
,
// Performs platform setup and resolves the proxy host before each test.
43 void SetUp() override
;
// Captures NetLog events so tests can check SOCKS5_CONNECT begin/end entries.
47 CapturingNetLog net_log_
;
// The SOCKS5 socket under test; tests assign it from BuildMockSocket().
48 scoped_ptr
<SOCKS5ClientSocket
> user_sock_
;
// Addresses the mock TCP socket is created with in BuildMockSocket().
49 AddressList address_list_
;
50 // Filled in by BuildMockSocket() and owned by its return value
51 // (which |user_sock_| is set to).
52 StreamSocket
* tcp_sock_
;
// Completion callback shared by the tests' user_sock_ operations.
53 TestCompletionCallback callback_
;
// Resolver used by SetUp() to resolve the proxy host.
54 scoped_ptr
<MockHostResolver
> host_resolver_
;
// Scripted reads/writes for the mock transport; reset by BuildMockSocket().
55 scoped_ptr
<SocketDataProvider
> data_
;
58 DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest
);
// Initializes kNwPort to port 80 in network byte order (tests append it to
// expected SOCKS5 request bytes) and creates the mock host resolver.
// NOTE(review): the constructor body/closing brace (original line 64) is
// missing from this listing.
61 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
62 : kNwPort(base::HostToNet16(80)),
63 host_resolver_(new MockHostResolver
) {
66 // Runs before every test case: performs platform setup, then resolves the
// proxy host with the mock resolver, which is expected to complete
// asynchronously (ERR_IO_PENDING).
67 void SOCKS5ClientSocketTest::SetUp() {
68 PlatformTest::SetUp();
70 // Resolve the proxy host (www.socks-proxy.com:1080); the result is
// presumably stored in address_list_, the AddressList the mock TCP
// connection uses to connect.
// NOTE(review): the remaining Resolve() arguments (original lines 74-78)
// and the lines after WaitForResult() are missing from this listing.
71 HostResolver::RequestInfo
info(HostPortPair("www.socks-proxy.com", 1080));
72 TestCompletionCallback callback
;
73 int rv
= host_resolver_
->Resolve(info
,
79 ASSERT_EQ(ERR_IO_PENDING
, rv
);
80 rv
= callback
.WaitForResult();
// Builds the socket under test. |reads|/|writes| script the transport and are
// stored in data_ (replacing any previous provider). A MockTCPClientSocket on
// address_list_ is connected (expected to complete asynchronously) and handed
// to a ClientSocketHandle. Returns a SOCKS5ClientSocket whose request targets
// |hostname|:|port|; tcp_sock_ stays as a non-owning pointer to the transport.
// NOTE(review): several lines (e.g. 85-88, 90-91, 100, 108) are missing from
// this listing, including the constructor argument that presumably passes
// |connection| on.
84 scoped_ptr
<SOCKS5ClientSocket
> SOCKS5ClientSocketTest::BuildMockSocket(
89 const std::string
& hostname
,
92 TestCompletionCallback callback
;
93 data_
.reset(new StaticSocketDataProvider(reads
, reads_count
,
94 writes
, writes_count
));
95 tcp_sock_
= new MockTCPClientSocket(address_list_
, net_log
, data_
.get());
97 int rv
= tcp_sock_
->Connect(callback
.callback());
98 EXPECT_EQ(ERR_IO_PENDING
, rv
);
99 rv
= callback
.WaitForResult();
101 EXPECT_TRUE(tcp_sock_
->IsConnected());
103 scoped_ptr
<ClientSocketHandle
> connection(new ClientSocketHandle
);
104 // |connection| takes ownership of |tcp_sock_|, but keep a
105 // non-owning pointer to it.
106 connection
->SetSocket(scoped_ptr
<StreamSocket
>(tcp_sock_
));
107 return scoped_ptr
<SOCKS5ClientSocket
>(new SOCKS5ClientSocket(
109 HostResolver::RequestInfo(HostPortPair(hostname
, port
))));
112 // Tests a complete SOCKS5 handshake, a payload write and read through the
// established connection, and the disconnection.
113 TEST_F(SOCKS5ClientSocketTest
, CompleteHandshake
) {
114 const std::string payload_write
= "random data";
115 const std::string payload_read
= "moar random data";
// Expected CONNECT request for localhost:80.
// NOTE(review): bytes from original lines 118, 120 and 123, and the closing
// "};", are missing from this listing.
117 const char kOkRequest
[] = {
119 0x01, // Command (CONNECT)
121 0x03, // Address type (DOMAINNAME).
122 0x09, // Length of domain (9)
124 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
125 0x00, 0x50, // 16-bit port (80)
// Scripted transport traffic: greeting, CONNECT request/response and
// payload, all delivered asynchronously.
128 MockWrite data_writes
[] = {
129 MockWrite(ASYNC
, kSOCKS5GreetRequest
, kSOCKS5GreetRequestLength
),
130 MockWrite(ASYNC
, kOkRequest
, arraysize(kOkRequest
)),
131 MockWrite(ASYNC
, payload_write
.data(), payload_write
.size()) };
132 MockRead data_reads
[] = {
133 MockRead(ASYNC
, kSOCKS5GreetResponse
, kSOCKS5GreetResponseLength
),
134 MockRead(ASYNC
, kSOCKS5OkResponse
, kSOCKS5OkResponseLength
),
135 MockRead(ASYNC
, payload_read
.data(), payload_read
.size()) };
137 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
138 data_writes
, arraysize(data_writes
),
139 "localhost", 80, &net_log_
);
141 // At this point the TCP connection is complete but the SOCKS handshake is not.
142 EXPECT_TRUE(tcp_sock_
->IsConnected());
143 EXPECT_FALSE(user_sock_
->IsConnected());
// The handshake completes asynchronously; the SOCKS5_CONNECT begin event is
// already logged while it is pending.
145 int rv
= user_sock_
->Connect(callback_
.callback());
146 EXPECT_EQ(ERR_IO_PENDING
, rv
);
147 EXPECT_FALSE(user_sock_
->IsConnected());
149 CapturingNetLog::CapturedEntryList net_log_entries
;
150 net_log_
.GetEntries(&net_log_entries
);
151 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries
, 0,
152 NetLog::TYPE_SOCKS5_CONNECT
));
154 rv
= callback_
.WaitForResult();
157 EXPECT_TRUE(user_sock_
->IsConnected());
159 net_log_
.GetEntries(&net_log_entries
);
160 EXPECT_TRUE(LogContainsEndEvent(net_log_entries
, -1,
161 NetLog::TYPE_SOCKS5_CONNECT
));
// Once connected, the payload write goes through to the scripted transport.
163 scoped_refptr
<IOBuffer
> buffer(new IOBuffer(payload_write
.size()));
164 memcpy(buffer
->data(), payload_write
.data(), payload_write
.size());
165 rv
= user_sock_
->Write(
166 buffer
.get(), payload_write
.size(), callback_
.callback());
167 EXPECT_EQ(ERR_IO_PENDING
, rv
);
168 rv
= callback_
.WaitForResult();
169 EXPECT_EQ(static_cast<int>(payload_write
.size()), rv
);
// The scripted payload is then read back through the SOCKS socket.
// NOTE(review): the "rv =" that should capture Read()'s return value
// (original line 172) is missing from this listing; without it the
// ERR_IO_PENDING check below would see the Write() result instead.
171 buffer
= new IOBuffer(payload_read
.size());
173 user_sock_
->Read(buffer
.get(), payload_read
.size(), callback_
.callback());
174 EXPECT_EQ(ERR_IO_PENDING
, rv
);
175 rv
= callback_
.WaitForResult();
176 EXPECT_EQ(static_cast<int>(payload_read
.size()), rv
);
177 EXPECT_EQ(payload_read
, std::string(buffer
->data(), payload_read
.size()));
// Disconnecting the SOCKS socket also disconnects the transport socket.
179 user_sock_
->Disconnect();
180 EXPECT_FALSE(tcp_sock_
->IsConnected());
181 EXPECT_FALSE(user_sock_
->IsConnected());
184 // Test that you can call Connect() again after having called Disconnect().
// Each iteration builds a fresh mock transport whose I/O is SYNCHRONOUS.
185 TEST_F(SOCKS5ClientSocketTest
, ConnectAndDisconnectTwice
) {
186 const std::string hostname
= "my-host-name";
// NOTE(review): the header bytes of kSOCKS5DomainRequest (original lines
// 188-193) are missing from this listing.
187 const char kSOCKS5DomainRequest
[] = {
// Expected CONNECT request: fixed header, one length byte, the hostname,
// then the port in network byte order (kNwPort).
194 std::string
request(kSOCKS5DomainRequest
, arraysize(kSOCKS5DomainRequest
));
195 request
.push_back(hostname
.size());
196 request
.append(hostname
);
197 request
.append(reinterpret_cast<const char*>(&kNwPort
), sizeof(kNwPort
));
// NOTE(review): the closing "};" of both mock arrays and the remaining
// BuildMockSocket() arguments (original lines 203, 207, 211-212) are missing
// from this listing.
199 for (int i
= 0; i
< 2; ++i
) {
200 MockWrite data_writes
[] = {
201 MockWrite(SYNCHRONOUS
, kSOCKS5GreetRequest
, kSOCKS5GreetRequestLength
),
202 MockWrite(SYNCHRONOUS
, request
.data(), request
.size())
204 MockRead data_reads
[] = {
205 MockRead(SYNCHRONOUS
, kSOCKS5GreetResponse
, kSOCKS5GreetResponseLength
),
206 MockRead(SYNCHRONOUS
, kSOCKS5OkResponse
, kSOCKS5OkResponseLength
)
209 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
210 data_writes
, arraysize(data_writes
),
// With SYNCHRONOUS mocks the socket is expected to be connected as soon as
// Connect() returns.
213 int rv
= user_sock_
->Connect(callback_
.callback());
215 EXPECT_TRUE(user_sock_
->IsConnected());
217 user_sock_
->Disconnect();
218 EXPECT_FALSE(user_sock_
->IsConnected());
222 // Test that we fail trying to connect to a hostname longer than 255 bytes.
223 TEST_F(SOCKS5ClientSocketTest
, LargeHostNameFails
) {
224 // Create a string of length 256, where each character is 'x'. The CONNECT
// request carries the domain length in a single byte, so 256 does not fit.
225 std::string large_host_name
;
226 std::fill_n(std::back_inserter(large_host_name
), 256, 'x');
228 // Create a SOCKS socket over a mock transport with no scripted traffic and
// no NetLog (NULL).
229 MockWrite data_writes
[] = {MockWrite()};
230 MockRead data_reads
[] = {MockRead()};
231 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
232 data_writes
, arraysize(data_writes
),
233 large_host_name
, 80, NULL
);
235 // Try to connect -- should fail (without having read/written anything to
236 // the transport socket first) because the hostname is too long. The error
// is returned directly from Connect(), not via the callback.
// NOTE(review): the closing brace (original line 240) is missing from this
// listing.
237 TestCompletionCallback callback
;
238 int rv
= user_sock_
->Connect(callback
.callback());
239 EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED
, rv
);
242 TEST_F(SOCKS5ClientSocketTest
, PartialReadWrites
) {
243 const std::string hostname
= "www.google.com";
245 const char kOkRequest
[] = {
247 0x01, // Command (CONNECT)
249 0x03, // Address type (DOMAINNAME).
250 0x0E, // Length of domain (14)
252 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
253 0x00, 0x50, // 16-bit port (80)
256 // Test for partial greet request write
258 const char partial1
[] = { 0x05, 0x01 };
259 const char partial2
[] = { 0x00 };
260 MockWrite data_writes
[] = {
261 MockWrite(ASYNC
, partial1
, arraysize(partial1
)),
262 MockWrite(ASYNC
, partial2
, arraysize(partial2
)),
263 MockWrite(ASYNC
, kOkRequest
, arraysize(kOkRequest
)) };
264 MockRead data_reads
[] = {
265 MockRead(ASYNC
, kSOCKS5GreetResponse
, kSOCKS5GreetResponseLength
),
266 MockRead(ASYNC
, kSOCKS5OkResponse
, kSOCKS5OkResponseLength
) };
267 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
268 data_writes
, arraysize(data_writes
),
269 hostname
, 80, &net_log_
);
270 int rv
= user_sock_
->Connect(callback_
.callback());
271 EXPECT_EQ(ERR_IO_PENDING
, rv
);
273 CapturingNetLog::CapturedEntryList net_log_entries
;
274 net_log_
.GetEntries(&net_log_entries
);
275 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries
, 0,
276 NetLog::TYPE_SOCKS5_CONNECT
));
278 rv
= callback_
.WaitForResult();
280 EXPECT_TRUE(user_sock_
->IsConnected());
282 net_log_
.GetEntries(&net_log_entries
);
283 EXPECT_TRUE(LogContainsEndEvent(net_log_entries
, -1,
284 NetLog::TYPE_SOCKS5_CONNECT
));
287 // Test for partial greet response read
289 const char partial1
[] = { 0x05 };
290 const char partial2
[] = { 0x00 };
291 MockWrite data_writes
[] = {
292 MockWrite(ASYNC
, kSOCKS5GreetRequest
, kSOCKS5GreetRequestLength
),
293 MockWrite(ASYNC
, kOkRequest
, arraysize(kOkRequest
)) };
294 MockRead data_reads
[] = {
295 MockRead(ASYNC
, partial1
, arraysize(partial1
)),
296 MockRead(ASYNC
, partial2
, arraysize(partial2
)),
297 MockRead(ASYNC
, kSOCKS5OkResponse
, kSOCKS5OkResponseLength
) };
298 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
299 data_writes
, arraysize(data_writes
),
300 hostname
, 80, &net_log_
);
301 int rv
= user_sock_
->Connect(callback_
.callback());
302 EXPECT_EQ(ERR_IO_PENDING
, rv
);
304 CapturingNetLog::CapturedEntryList net_log_entries
;
305 net_log_
.GetEntries(&net_log_entries
);
306 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries
, 0,
307 NetLog::TYPE_SOCKS5_CONNECT
));
308 rv
= callback_
.WaitForResult();
310 EXPECT_TRUE(user_sock_
->IsConnected());
311 net_log_
.GetEntries(&net_log_entries
);
312 EXPECT_TRUE(LogContainsEndEvent(net_log_entries
, -1,
313 NetLog::TYPE_SOCKS5_CONNECT
));
316 // Test for partial handshake request write.
318 const int kSplitPoint
= 3; // Break handshake write into two parts.
319 MockWrite data_writes
[] = {
320 MockWrite(ASYNC
, kSOCKS5GreetRequest
, kSOCKS5GreetRequestLength
),
321 MockWrite(ASYNC
, kOkRequest
, kSplitPoint
),
322 MockWrite(ASYNC
, kOkRequest
+ kSplitPoint
,
323 arraysize(kOkRequest
) - kSplitPoint
)
325 MockRead data_reads
[] = {
326 MockRead(ASYNC
, kSOCKS5GreetResponse
, kSOCKS5GreetResponseLength
),
327 MockRead(ASYNC
, kSOCKS5OkResponse
, kSOCKS5OkResponseLength
) };
328 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
329 data_writes
, arraysize(data_writes
),
330 hostname
, 80, &net_log_
);
331 int rv
= user_sock_
->Connect(callback_
.callback());
332 EXPECT_EQ(ERR_IO_PENDING
, rv
);
333 CapturingNetLog::CapturedEntryList net_log_entries
;
334 net_log_
.GetEntries(&net_log_entries
);
335 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries
, 0,
336 NetLog::TYPE_SOCKS5_CONNECT
));
337 rv
= callback_
.WaitForResult();
339 EXPECT_TRUE(user_sock_
->IsConnected());
340 net_log_
.GetEntries(&net_log_entries
);
341 EXPECT_TRUE(LogContainsEndEvent(net_log_entries
, -1,
342 NetLog::TYPE_SOCKS5_CONNECT
));
345 // Test for partial handshake response read
347 const int kSplitPoint
= 6; // Break the handshake read into two parts.
348 MockWrite data_writes
[] = {
349 MockWrite(ASYNC
, kSOCKS5GreetRequest
, kSOCKS5GreetRequestLength
),
350 MockWrite(ASYNC
, kOkRequest
, arraysize(kOkRequest
))
352 MockRead data_reads
[] = {
353 MockRead(ASYNC
, kSOCKS5GreetResponse
, kSOCKS5GreetResponseLength
),
354 MockRead(ASYNC
, kSOCKS5OkResponse
, kSplitPoint
),
355 MockRead(ASYNC
, kSOCKS5OkResponse
+ kSplitPoint
,
356 kSOCKS5OkResponseLength
- kSplitPoint
)
359 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
360 data_writes
, arraysize(data_writes
),
361 hostname
, 80, &net_log_
);
362 int rv
= user_sock_
->Connect(callback_
.callback());
363 EXPECT_EQ(ERR_IO_PENDING
, rv
);
364 CapturingNetLog::CapturedEntryList net_log_entries
;
365 net_log_
.GetEntries(&net_log_entries
);
366 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries
, 0,
367 NetLog::TYPE_SOCKS5_CONNECT
));
368 rv
= callback_
.WaitForResult();
370 EXPECT_TRUE(user_sock_
->IsConnected());
371 net_log_
.GetEntries(&net_log_entries
);
372 EXPECT_TRUE(LogContainsEndEvent(net_log_entries
, -1,
373 NetLog::TYPE_SOCKS5_CONNECT
));