1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/pseudotcp_adapter.h"
10 #include "base/bind_helpers.h"
11 #include "base/compiler_specific.h"
12 #include "jingle/glue/thread_wrapper.h"
13 #include "net/base/io_buffer.h"
14 #include "net/base/net_errors.h"
15 #include "net/base/test_completion_callback.h"
16 #include "remoting/protocol/p2p_datagram_socket.h"
17 #include "remoting/protocol/p2p_stream_socket.h"
18 #include "testing/gmock/include/gmock/gmock.h"
19 #include "testing/gtest/include/gtest/gtest.h"
// Number of bytes kept free at the tail of the tester's input buffer for each
// read.
const int kMessageSize = 1024;

// Number of kMessageSize-sized pieces that make up the whole payload.
const int kMessages = 100;

// Total payload size in bytes: kMessages pieces of kMessageSize bytes each.
const int kTestDataSize = kMessageSize * kMessages;
32 virtual ~RateLimiter() { };
33 // Returns true if the new packet needs to be dropped, false otherwise.
34 virtual bool DropNextPacket() = 0;
// RateLimiter implementation that tracks a fill level. Each DropNextPacket()
// call adds 1.0 to |level_| and subtracts the time elapsed since
// |last_update_| (in seconds) multiplied by |rate_|; the result is then
// compared against |volume_| and against 0.0.
// NOTE(review): this listing appears to omit lines of the class (the member
// initializers other than |last_update_|, the bodies of both branches in
// DropNextPacket(), the return statements and most field declarations) —
// confirm against the full file before relying on this description.
37 class LeakyBucket
: public RateLimiter
{
// Constructs the limiter and stamps |last_update_| with the current time.
// |volume| presumably initializes |volume_| (the upper bound tested in
// DropNextPacket()); the initializer is not visible here — TODO confirm.
39 // |rate| is in drops per second.
40 LeakyBucket(double volume
, double rate
)
44 last_update_(base::TimeTicks::Now()) {
47 ~LeakyBucket() override
{}
// RateLimiter implementation. Recomputes |level_| from the time elapsed since
// |last_update_|, then branches on the level exceeding |volume_| or falling
// below 0.0.
49 bool DropNextPacket() override
{
50 base::TimeTicks now
= base::TimeTicks::Now();
51 double interval
= (now
- last_update_
).InSecondsF();
// One unit is added for the packet being considered; interval * rate_ is the
// amount subtracted for the elapsed time.
53 level_
= level_
+ 1.0 - interval
* rate_
;
54 if (level_
> volume_
) {
57 } else if (level_
< 0.0) {
// Reference time used to compute the elapsed interval in DropNextPacket().
67 base::TimeTicks last_update_
;
// P2PDatagramSocket implementation used by the tests below. Two instances are
// linked with Connect(); Send() on one posts a task that calls
// AppendInputPacket() on its peer, optionally delayed (set_latency()) and
// subject to packet loss (set_rate_limiter()).
70 class FakeSocket
: public P2PDatagramSocket
{
// Constructor initializer list: starts with no rate limiter installed.
// NOTE(review): the constructor's signature line and any other initializers
// are not visible in this listing.
73 : rate_limiter_(NULL
),
76 ~FakeSocket() override
{}
78 void AppendInputPacket(const std::vector
<char>& data
) {
79 if (rate_limiter_
&& rate_limiter_
->DropNextPacket())
80 return; // Lose the packet.
82 if (!read_callback_
.is_null()) {
83 int size
= std::min(read_buffer_size_
, static_cast<int>(data
.size()));
84 memcpy(read_buffer_
->data(), &data
[0], data
.size());
85 net::CompletionCallback cb
= read_callback_
;
86 read_callback_
.Reset();
90 incoming_packets_
.push_back(data
);
94 void Connect(FakeSocket
* peer_socket
) {
95 peer_socket_
= peer_socket
;
98 void set_rate_limiter(RateLimiter
* rate_limiter
) {
99 rate_limiter_
= rate_limiter
;
102 void set_latency(int latency_ms
) { latency_ms_
= latency_ms
; };
104 // P2PDatagramSocket interface.
// Reads one datagram into |buf| (capacity |buf_len|). Only one read may be
// pending at a time (enforced by the CHECK). When a packet is already queued
// its front entry is copied into |buf| and removed from the queue; otherwise
// |callback| and |buf_len| are stored and ERR_IO_PENDING is returned, to be
// completed by a later AppendInputPacket().
// NOTE(review): the declaration of |size|, the return on the queued path and
// the branch separating the two paths are not visible in this listing.
105 int Recv(const scoped_refptr
<net::IOBuffer
>& buf
, int buf_len
,
106 const net::CompletionCallback
& callback
) override
{
107 CHECK(read_callback_
.is_null());
110 if (incoming_packets_
.size() > 0) {
111 scoped_refptr
<net::IOBuffer
> buffer(buf
);
// Arguments of a call whose opening line is not visible here; together with
// the memcpy below this suggests |size| is min(front packet size, buf_len) —
// TODO confirm.
113 static_cast<int>(incoming_packets_
.front().size()), buf_len
);
114 memcpy(buffer
->data(), &*incoming_packets_
.front().begin(), size
);
115 incoming_packets_
.pop_front();
// No packet available: remember the pending read.
118 read_callback_
= callback
;
120 read_buffer_size_
= buf_len
;
121 return net::ERR_IO_PENDING
;
// Sends the first |buf_len| bytes of |buf| to the peer socket: the bytes are
// copied into a std::vector<char> and a task calling
// FakeSocket::AppendInputPacket on |peer_socket_| is posted to the current
// message loop with a delay of |latency_ms_| milliseconds. |peer_socket_| is
// bound with base::Unretained, so the peer must outlive the posted task.
// NOTE(review): the task location argument, any guard around the post and the
// return statement are not visible in this listing; |callback| is not used in
// the visible lines.
125 int Send(const scoped_refptr
<net::IOBuffer
>& buf
, int buf_len
,
126 const net::CompletionCallback
& callback
) override
{
129 base::MessageLoop::current()->PostDelayedTask(
131 base::Bind(&FakeSocket::AppendInputPacket
,
132 base::Unretained(peer_socket_
),
133 std::vector
<char>(buf
->data(), buf
->data() + buf_len
)),
134 base::TimeDelta::FromMilliseconds(latency_ms_
));
// State of the pending Recv(), if any: destination buffer, its capacity in
// bytes, and the callback to run once a packet arrives.
141 scoped_refptr
<net::IOBuffer
> read_buffer_
;
142 int read_buffer_size_
;
143 net::CompletionCallback read_callback_
;
// Packets that arrived while no Recv() was pending, oldest first (pushed at
// the back, consumed from the front).
145 std::deque
<std::vector
<char> > incoming_packets_
;
// Destination of Send(); set by Connect().
147 FakeSocket
* peer_socket_
;
// Optional packet-loss policy; NULL when none is installed.
148 RateLimiter
* rate_limiter_
;
// Ref-counted test driver that pushes kTestDataSize bytes through a pair of
// P2PStreamSocket objects: it writes to |client_socket_|, reads from
// |host_socket_| and lets the test compare what was received with what was
// sent (see CheckResults()).
152 class TCPChannelTester
: public base::RefCountedThreadSafe
<TCPChannelTester
> {
// Stores the message loop and the two sockets; none of them is owned.
// NOTE(review): the tests in this file pass (host adapter, client adapter) for
// the (client_socket, host_socket) parameters, so the "client"/"host" names
// here are swapped relative to the callers.
// NOTE(review): the end of the initializer list and the constructor body are
// not visible in this listing.
154 TCPChannelTester(base::MessageLoop
* message_loop
,
155 P2PStreamSocket
* client_socket
,
156 P2PStreamSocket
* host_socket
)
157 : message_loop_(message_loop
),
158 host_socket_(host_socket
),
159 client_socket_(client_socket
),
// Body of the method that starts the test (its signature line is not visible
// in this listing): posts TCPChannelTester::DoStart to |message_loop_|. Binding
// |this| without base::Unretained keeps a reference to this ref-counted
// object until the task has run.
165 message_loop_
->PostTask(
166 FROM_HERE
, base::Bind(&TCPChannelTester::DoStart
, this));
169 void CheckResults() {
170 EXPECT_EQ(0, write_errors_
);
171 EXPECT_EQ(0, read_errors_
);
173 ASSERT_EQ(kTestDataSize
+ kMessageSize
, input_buffer_
->capacity());
175 output_buffer_
->SetOffset(0);
176 ASSERT_EQ(kTestDataSize
, output_buffer_
->size());
178 EXPECT_EQ(0, memcmp(output_buffer_
->data(),
179 input_buffer_
->StartOfBuffer(), kTestDataSize
));
// Destructor; the class declares base::RefCountedThreadSafe<TCPChannelTester>
// as a friend, which is what destroys instances.
183 virtual ~TCPChannelTester() {}
// Statement from the method that finishes the test (its signature and any
// other statements are not visible in this listing): posts a quit closure so
// that the message loop stops running.
187 message_loop_
->PostTask(FROM_HERE
, base::MessageLoop::QuitClosure());
// Buffer set-up (the enclosing method's signature line is not visible in this
// listing). The output buffer holds kTestDataSize bytes, every byte set to
// 123; the input buffer starts with room for a single kMessageSize read.
197 output_buffer_
= new net::DrainableIOBuffer(
198 new net::IOBuffer(kTestDataSize
), kTestDataSize
);
199 memset(output_buffer_
->data(), 123, kTestDataSize
);
201 input_buffer_
= new net::GrowableIOBuffer();
202 // Always keep kMessageSize bytes available at the end of the input buffer.
203 input_buffer_
->SetCapacity(kMessageSize
);
// Write step (the enclosing method's signature, its loop/exit statements and
// part of the Write() argument list are not visible in this listing). Checks
// whether any payload is left in |output_buffer_|, writes the next piece to
// |client_socket_| with OnWritten as the completion callback, and hands the
// immediate return value to HandleWriteResult().
209 if (output_buffer_
->BytesRemaining() == 0)
// The second argument of std::min is not visible here; presumably it caps the
// write at kMessageSize — TODO confirm.
212 int bytes_to_write
= std::min(output_buffer_
->BytesRemaining(),
214 result
= client_socket_
->Write(
215 output_buffer_
.get(),
217 base::Bind(&TCPChannelTester::OnWritten
, base::Unretained(this)));
218 HandleWriteResult(result
);
// Completion callback bound to the Write() call: forwards |result| to
// HandleWriteResult().
// NOTE(review): the end of this method is not visible in this listing.
222 void OnWritten(int result
) {
223 HandleWriteResult(result
);
// Processes the result of a write. A non-positive |result| other than
// ERR_IO_PENDING is logged as an error; a positive |result| is the number of
// bytes written and is consumed from |output_buffer_|. ERR_IO_PENDING falls
// through both branches.
// NOTE(review): the remaining statements of the error branch and the end of
// the method are not visible in this listing.
227 void HandleWriteResult(int result
) {
228 if (result
<= 0 && result
!= net::ERR_IO_PENDING
) {
229 LOG(ERROR
) << "Received error " << result
<< " when trying to write";
232 } else if (result
> 0) {
233 output_buffer_
->DidConsume(result
);
// Read step (the enclosing method's signature, its loop statements and part of
// the Read() argument list are not visible in this listing). Positions the
// input buffer's offset kMessageSize bytes before the end of its capacity,
// reads from |host_socket_| with OnRead as the completion callback, and hands
// the immediate return value to HandleReadResult().
240 input_buffer_
->set_offset(input_buffer_
->capacity() - kMessageSize
);
242 result
= host_socket_
->Read(
245 base::Bind(&TCPChannelTester::OnRead
, base::Unretained(this)));
246 HandleReadResult(result
);
// Completion callback bound to the Read() call: forwards |result| to
// HandleReadResult().
// NOTE(review): the end of this method is not visible in this listing.
250 void OnRead(int result
) {
251 HandleReadResult(result
);
// Processes the result of a read. A non-positive |result| other than
// ERR_IO_PENDING is logged as an error; a positive |result| is the number of
// bytes received, and the input buffer's capacity is grown by that amount.
// Once the capacity reaches kTestDataSize + kMessageSize the whole payload has
// been received.
// NOTE(review): the remaining statements of the error branch, the statement
// guarded by the final `if` and the end of the method are not visible in this
// listing.
255 void HandleReadResult(int result
) {
256 if (result
<= 0 && result
!= net::ERR_IO_PENDING
) {
258 LOG(ERROR
) << "Received error " << result
<< " when trying to read";
262 } else if (result
> 0) {
263 // Allocate memory for the next read.
264 input_buffer_
->SetCapacity(input_buffer_
->capacity() + result
);
265 if (input_buffer_
->capacity() == kTestDataSize
+ kMessageSize
)
// Grants the ref-counting base class access to this class.
271 friend class base::RefCountedThreadSafe
<TCPChannelTester
>;
// Loop on which the test is started and to which the quit closure is posted.
273 base::MessageLoop
* message_loop_
;
// Socket the test reads from.
274 P2PStreamSocket
* host_socket_
;
// Socket the test writes to.
275 P2PStreamSocket
* client_socket_
;
// Payload to send; consumed as writes complete.
278 scoped_refptr
<net::DrainableIOBuffer
> output_buffer_
;
// Received bytes; grows by the size of every completed read.
279 scoped_refptr
<net::GrowableIOBuffer
> input_buffer_
;
// Test fixture: builds two FakeSocket objects connected to each other and
// wraps each one in a PseudoTcpAdapter.
// NOTE(review): access specifiers and closing braces are not visible in this
// listing.
285 class PseudoTcpAdapterTest
: public testing::Test
{
287 void SetUp() override
{
288 jingle_glue::JingleThreadWrapper::EnsureForCurrentMessageLoop();
290 host_socket_
= new FakeSocket();
291 client_socket_
= new FakeSocket();
// Cross-connect the sockets so that each one delivers to the other.
293 host_socket_
->Connect(client_socket_
);
294 client_socket_
->Connect(host_socket_
);
// Each socket is handed to its adapter through make_scoped_ptr(); the raw
// pointers kept in the fixture are used by tests to configure the sockets.
296 host_pseudotcp_
.reset(new PseudoTcpAdapter(make_scoped_ptr(host_socket_
)));
297 client_pseudotcp_
.reset(
298 new PseudoTcpAdapter(make_scoped_ptr(client_socket_
)));
// Raw pointers to the sockets that were passed to the adapters in SetUp().
301 FakeSocket
* host_socket_
;
302 FakeSocket
* client_socket_
;
// Adapters under test.
304 scoped_ptr
<PseudoTcpAdapter
> host_pseudotcp_
;
305 scoped_ptr
<PseudoTcpAdapter
> client_pseudotcp_
;
// Message loop handed to the testers and helpers created by the tests.
306 base::MessageLoop message_loop_
;
// Connects the host and client adapters (waiting on the completion callback
// for whichever Connect() call returns ERR_IO_PENDING), requires both to
// succeed, then creates a TCPChannelTester over the two adapters and checks
// its results.
// NOTE(review): no statement that starts the tester or runs the message loop
// is visible between the tester's construction and CheckResults() in this
// listing, and the closing brace is absent — confirm against the full file.
309 TEST_F(PseudoTcpAdapterTest
, DataTransfer
) {
310 net::TestCompletionCallback host_connect_cb
;
311 net::TestCompletionCallback client_connect_cb
;
313 int rv1
= host_pseudotcp_
->Connect(host_connect_cb
.callback());
314 int rv2
= client_pseudotcp_
->Connect(client_connect_cb
.callback());
316 if (rv1
== net::ERR_IO_PENDING
)
317 rv1
= host_connect_cb
.WaitForResult();
318 if (rv2
== net::ERR_IO_PENDING
)
319 rv2
= client_connect_cb
.WaitForResult();
320 ASSERT_EQ(net::OK
, rv1
);
321 ASSERT_EQ(net::OK
, rv2
);
323 scoped_refptr
<TCPChannelTester
> tester
=
324 new TCPChannelTester(&message_loop_
, host_pseudotcp_
.get(),
325 client_pseudotcp_
.get());
329 tester
->CheckResults();
332 TEST_F(PseudoTcpAdapterTest
, LimitedChannel
) {
333 const int kLatencyMs
= 20;
334 const int kPacketsPerSecond
= 400;
335 const int kBurstPackets
= 10;
337 LeakyBucket
host_limiter(kBurstPackets
, kPacketsPerSecond
);
338 host_socket_
->set_latency(kLatencyMs
);
339 host_socket_
->set_rate_limiter(&host_limiter
);
341 LeakyBucket
client_limiter(kBurstPackets
, kPacketsPerSecond
);
342 host_socket_
->set_latency(kLatencyMs
);
343 client_socket_
->set_rate_limiter(&client_limiter
);
345 net::TestCompletionCallback host_connect_cb
;
346 net::TestCompletionCallback client_connect_cb
;
348 int rv1
= host_pseudotcp_
->Connect(host_connect_cb
.callback());
349 int rv2
= client_pseudotcp_
->Connect(client_connect_cb
.callback());
351 if (rv1
== net::ERR_IO_PENDING
)
352 rv1
= host_connect_cb
.WaitForResult();
353 if (rv2
== net::ERR_IO_PENDING
)
354 rv2
= client_connect_cb
.WaitForResult();
355 ASSERT_EQ(net::OK
, rv1
);
356 ASSERT_EQ(net::OK
, rv2
);
358 scoped_refptr
<TCPChannelTester
> tester
=
359 new TCPChannelTester(&message_loop_
, host_pseudotcp_
.get(),
360 client_pseudotcp_
.get());
364 tester
->CheckResults();
// Helper for the DeleteOnConnected test. Holds a message loop and a pointer to
// the scoped_ptr that owns a PseudoTcpAdapter; OnConnected() is bound as that
// adapter's connect callback.
// NOTE(review): access specifiers and closing braces are not visible in this
// listing.
367 class DeleteOnConnected
{
// Neither |message_loop| nor |adapter| is owned by this object.
369 DeleteOnConnected(base::MessageLoop
* message_loop
,
370 scoped_ptr
<PseudoTcpAdapter
>* adapter
)
371 : message_loop_(message_loop
), adapter_(adapter
) {}
// Connect callback: posts a quit closure to |message_loop_|. |error| is not
// used in the visible lines.
// NOTE(review): presumably this also resets |*adapter_| (the test asserts that
// the adapter pointer is NULL afterwards), but that statement is not visible
// in this listing — confirm against the full file.
372 void OnConnected(int error
) {
374 message_loop_
->PostTask(FROM_HERE
, base::MessageLoop::QuitClosure());
376 base::MessageLoop
* message_loop_
;
377 scoped_ptr
<PseudoTcpAdapter
>* adapter_
;
// Connects the host adapter with DeleteOnConnected::OnConnected as its connect
// callback and the client adapter with an ordinary test callback, then
// requires that the host adapter pointer has become NULL.
// NOTE(review): the statement that runs the message loop before the final
// assertion and the closing brace are not visible in this listing.
380 TEST_F(PseudoTcpAdapterTest
, DeleteOnConnected
) {
381 // This test verifies that deleting the adapter mid-callback doesn't lead
382 // to deleted structures being touched as the stack unrolls, so the failure
383 // mode is a crash rather than a normal test failure.
384 net::TestCompletionCallback client_connect_cb
;
385 DeleteOnConnected
host_delete(&message_loop_
, &host_pseudotcp_
);
// |host_delete| lives on this stack frame, so binding it with
// base::Unretained is valid only while this test body is executing.
387 host_pseudotcp_
->Connect(base::Bind(&DeleteOnConnected::OnConnected
,
388 base::Unretained(&host_delete
)));
389 client_pseudotcp_
->Connect(client_connect_cb
.callback());
392 ASSERT_EQ(NULL
, host_pseudotcp_
.get());
395 // Verify that we can send/receive data with the write-waits-for-send
// flag set.
// Runs the connect / transfer / verify sequence with SetWriteWaitsForSend(true)
// applied to both adapters and SetNoDelay(true) applied to the host adapter.
// NOTE(review): no statement that starts the tester or runs the message loop
// is visible between the tester's construction and CheckResults() in this
// listing, and the closing brace is absent — confirm against the full file.
397 TEST_F(PseudoTcpAdapterTest
, WriteWaitsForSendLetsDataThrough
) {
398 net::TestCompletionCallback host_connect_cb
;
399 net::TestCompletionCallback client_connect_cb
;
401 host_pseudotcp_
->SetWriteWaitsForSend(true);
402 client_pseudotcp_
->SetWriteWaitsForSend(true);
404 // Disable Nagle's algorithm because the test is slow when it is
// enabled.
406 host_pseudotcp_
->SetNoDelay(true);
408 int rv1
= host_pseudotcp_
->Connect(host_connect_cb
.callback());
409 int rv2
= client_pseudotcp_
->Connect(client_connect_cb
.callback());
// Connect() may complete asynchronously; wait for whichever side is pending.
411 if (rv1
== net::ERR_IO_PENDING
)
412 rv1
= host_connect_cb
.WaitForResult();
413 if (rv2
== net::ERR_IO_PENDING
)
414 rv2
= client_connect_cb
.WaitForResult();
415 ASSERT_EQ(net::OK
, rv1
);
416 ASSERT_EQ(net::OK
, rv2
);
418 scoped_refptr
<TCPChannelTester
> tester
=
419 new TCPChannelTester(&message_loop_
, host_pseudotcp_
.get(),
420 client_pseudotcp_
.get());
424 tester
->CheckResults();
429 } // namespace protocol
430 } // namespace remoting