1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/stl_util.h"
14 #include "base/thread_task_runner_handle.h"
15 #include "net/base/net_errors.h"
16 #include "net/socket/stream_socket.h"
17 #include "remoting/protocol/util.h"
// Sentinel held in MuxChannel::receive_id_ until the first packet carrying a
// previously unseen channel_id assigns the channel its receive id.
const int kChannelIdUnknown(-1);

// Upper bound, in bytes, on the payload MuxSocket::Write() copies into a
// single MultiplexPacket; a longer buffer is truncated to this many bytes.
const int kMaxPacketSize(1 << 10);  // 1024 bytes.
28 PendingPacket(scoped_ptr
<MultiplexPacket
> packet
,
29 const base::Closure
& done_task
)
30 : packet(packet
.Pass()),
38 bool is_empty() { return pos
>= packet
->data().size(); }
40 int Read(char* buffer
, size_t size
) {
41 size
= std::min(size
, packet
->data().size() - pos
);
42 memcpy(buffer
, packet
->data().data() + pos
, size
);
48 scoped_ptr
<MultiplexPacket
> packet
;
49 base::Closure done_task
;
52 DISALLOW_COPY_AND_ASSIGN(PendingPacket
);
// Out-of-line definition of the static member ChannelMultiplexer::kMuxChannelName,
// which holds the channel name "mux".
// NOTE(review): where this name is used (e.g. as the |base_channel_name|
// handed to the ChannelMultiplexer constructor) is not visible here — confirm
// against callers.
57 const char ChannelMultiplexer::kMuxChannelName
[] = "mux";
59 struct ChannelMultiplexer::PendingChannel
{
60 PendingChannel(const std::string
& name
,
61 const StreamChannelCallback
& callback
)
62 : name(name
), callback(callback
) {
65 StreamChannelCallback callback
;
68 class ChannelMultiplexer::MuxChannel
{
70 MuxChannel(ChannelMultiplexer
* multiplexer
, const std::string
& name
,
74 const std::string
& name() { return name_
; }
75 int receive_id() { return receive_id_
; }
76 void set_receive_id(int id
) { receive_id_
= id
; }
78 // Called by ChannelMultiplexer.
79 scoped_ptr
<net::StreamSocket
> CreateSocket();
80 void OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
81 const base::Closure
& done_task
);
84 // Called by MuxSocket.
85 void OnSocketDestroyed();
86 bool DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
87 const base::Closure
& done_task
);
88 int DoRead(net::IOBuffer
* buffer
, int buffer_len
);
91 ChannelMultiplexer
* multiplexer_
;
97 std::list
<PendingPacket
*> pending_packets_
;
99 DISALLOW_COPY_AND_ASSIGN(MuxChannel
);
102 class ChannelMultiplexer::MuxSocket
: public net::StreamSocket
,
103 public base::NonThreadSafe
,
104 public base::SupportsWeakPtr
<MuxSocket
> {
106 MuxSocket(MuxChannel
* channel
);
107 virtual ~MuxSocket();
109 void OnWriteComplete();
110 void OnWriteFailed();
111 void OnPacketReceived();
113 // net::StreamSocket interface.
114 virtual int Read(net::IOBuffer
* buffer
, int buffer_len
,
115 const net::CompletionCallback
& callback
) OVERRIDE
;
116 virtual int Write(net::IOBuffer
* buffer
, int buffer_len
,
117 const net::CompletionCallback
& callback
) OVERRIDE
;
119 virtual bool SetReceiveBufferSize(int32 size
) OVERRIDE
{
123 virtual bool SetSendBufferSize(int32 size
) OVERRIDE
{
128 virtual int Connect(const net::CompletionCallback
& callback
) OVERRIDE
{
130 return net::ERR_FAILED
;
132 virtual void Disconnect() OVERRIDE
{
135 virtual bool IsConnected() const OVERRIDE
{
139 virtual bool IsConnectedAndIdle() const OVERRIDE
{
143 virtual int GetPeerAddress(net::IPEndPoint
* address
) const OVERRIDE
{
145 return net::ERR_FAILED
;
147 virtual int GetLocalAddress(net::IPEndPoint
* address
) const OVERRIDE
{
149 return net::ERR_FAILED
;
151 virtual const net::BoundNetLog
& NetLog() const OVERRIDE
{
155 virtual void SetSubresourceSpeculation() OVERRIDE
{
158 virtual void SetOmniboxSpeculation() OVERRIDE
{
161 virtual bool WasEverUsed() const OVERRIDE
{
164 virtual bool UsingTCPFastOpen() const OVERRIDE
{
167 virtual bool WasNpnNegotiated() const OVERRIDE
{
170 virtual net::NextProto
GetNegotiatedProtocol() const OVERRIDE
{
171 return net::kProtoUnknown
;
173 virtual bool GetSSLInfo(net::SSLInfo
* ssl_info
) OVERRIDE
{
179 MuxChannel
* channel_
;
181 net::CompletionCallback read_callback_
;
182 scoped_refptr
<net::IOBuffer
> read_buffer_
;
183 int read_buffer_size_
;
187 net::CompletionCallback write_callback_
;
189 net::BoundNetLog net_log_
;
191 DISALLOW_COPY_AND_ASSIGN(MuxSocket
);
195 ChannelMultiplexer::MuxChannel::MuxChannel(
196 ChannelMultiplexer
* multiplexer
,
197 const std::string
& name
,
199 : multiplexer_(multiplexer
),
203 receive_id_(kChannelIdUnknown
),
207 ChannelMultiplexer::MuxChannel::~MuxChannel() {
208 // Socket must be destroyed before the channel.
210 STLDeleteElements(&pending_packets_
);
213 scoped_ptr
<net::StreamSocket
> ChannelMultiplexer::MuxChannel::CreateSocket() {
214 DCHECK(!socket_
); // Can't create more than one socket per channel.
215 scoped_ptr
<MuxSocket
> result(new MuxSocket(this));
216 socket_
= result
.get();
217 return result
.PassAs
<net::StreamSocket
>();
220 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
221 scoped_ptr
<MultiplexPacket
> packet
,
222 const base::Closure
& done_task
) {
223 DCHECK_EQ(packet
->channel_id(), receive_id_
);
224 if (packet
->data().size() > 0) {
225 pending_packets_
.push_back(new PendingPacket(packet
.Pass(), done_task
));
227 // Notify the socket that we have more data.
228 socket_
->OnPacketReceived();
233 void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
235 socket_
->OnWriteFailed();
238 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
243 bool ChannelMultiplexer::MuxChannel::DoWrite(
244 scoped_ptr
<MultiplexPacket
> packet
,
245 const base::Closure
& done_task
) {
246 packet
->set_channel_id(send_id_
);
248 packet
->set_channel_name(name_
);
251 return multiplexer_
->DoWrite(packet
.Pass(), done_task
);
254 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer
* buffer
,
257 while (buffer_len
> 0 && !pending_packets_
.empty()) {
258 DCHECK(!pending_packets_
.front()->is_empty());
259 int result
= pending_packets_
.front()->Read(
260 buffer
->data() + pos
, buffer_len
);
261 DCHECK_LE(result
, buffer_len
);
264 if (pending_packets_
.front()->is_empty()) {
265 delete pending_packets_
.front();
266 pending_packets_
.erase(pending_packets_
.begin());
272 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel
* channel
)
274 read_buffer_size_(0),
275 write_pending_(false),
279 ChannelMultiplexer::MuxSocket::~MuxSocket() {
280 channel_
->OnSocketDestroyed();
283 int ChannelMultiplexer::MuxSocket::Read(
284 net::IOBuffer
* buffer
, int buffer_len
,
285 const net::CompletionCallback
& callback
) {
286 DCHECK(CalledOnValidThread());
287 DCHECK(read_callback_
.is_null());
289 int result
= channel_
->DoRead(buffer
, buffer_len
);
291 read_buffer_
= buffer
;
292 read_buffer_size_
= buffer_len
;
293 read_callback_
= callback
;
294 return net::ERR_IO_PENDING
;
299 int ChannelMultiplexer::MuxSocket::Write(
300 net::IOBuffer
* buffer
, int buffer_len
,
301 const net::CompletionCallback
& callback
) {
302 DCHECK(CalledOnValidThread());
304 scoped_ptr
<MultiplexPacket
> packet(new MultiplexPacket());
305 size_t size
= std::min(kMaxPacketSize
, buffer_len
);
306 packet
->mutable_data()->assign(buffer
->data(), size
);
308 write_pending_
= true;
309 bool result
= channel_
->DoWrite(packet
.Pass(), base::Bind(
310 &ChannelMultiplexer::MuxSocket::OnWriteComplete
, AsWeakPtr()));
313 // Cannot complete the write, e.g. if the connection has been terminated.
314 return net::ERR_FAILED
;
317 // OnWriteComplete() might be called above synchronously.
318 if (write_pending_
) {
319 DCHECK(write_callback_
.is_null());
320 write_callback_
= callback
;
321 write_result_
= size
;
322 return net::ERR_IO_PENDING
;
328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
329 write_pending_
= false;
330 if (!write_callback_
.is_null()) {
331 net::CompletionCallback cb
;
332 std::swap(cb
, write_callback_
);
333 cb
.Run(write_result_
);
337 void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
338 if (!write_callback_
.is_null()) {
339 net::CompletionCallback cb
;
340 std::swap(cb
, write_callback_
);
341 cb
.Run(net::ERR_FAILED
);
345 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
346 if (!read_callback_
.is_null()) {
347 int result
= channel_
->DoRead(read_buffer_
.get(), read_buffer_size_
);
349 DCHECK_GT(result
, 0);
350 net::CompletionCallback cb
;
351 std::swap(cb
, read_callback_
);
356 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory
* factory
,
357 const std::string
& base_channel_name
)
358 : base_channel_factory_(factory
),
359 base_channel_name_(base_channel_name
),
361 weak_factory_(this) {
364 ChannelMultiplexer::~ChannelMultiplexer() {
365 DCHECK(pending_channels_
.empty());
366 STLDeleteValues(&channels_
);
368 // Cancel creation of the base channel if it hasn't finished.
369 if (base_channel_factory_
)
370 base_channel_factory_
->CancelChannelCreation(base_channel_name_
);
373 void ChannelMultiplexer::CreateStreamChannel(
374 const std::string
& name
,
375 const StreamChannelCallback
& callback
) {
376 if (base_channel_
.get()) {
377 // Already have |base_channel_|. Create new multiplexed channel
379 callback
.Run(GetOrCreateChannel(name
)->CreateSocket());
380 } else if (!base_channel_
.get() && !base_channel_factory_
) {
381 // Fail synchronously if we failed to create |base_channel_|.
382 callback
.Run(scoped_ptr
<net::StreamSocket
>());
384 // Still waiting for the |base_channel_|.
385 pending_channels_
.push_back(PendingChannel(name
, callback
));
387 // If this is the first multiplexed channel then create the base channel.
388 if (pending_channels_
.size() == 1U) {
389 base_channel_factory_
->CreateStreamChannel(
391 base::Bind(&ChannelMultiplexer::OnBaseChannelReady
,
392 base::Unretained(this)));
397 void ChannelMultiplexer::CreateDatagramChannel(
398 const std::string
& name
,
399 const DatagramChannelCallback
& callback
) {
401 callback
.Run(scoped_ptr
<net::Socket
>());
404 void ChannelMultiplexer::CancelChannelCreation(const std::string
& name
) {
405 for (std::list
<PendingChannel
>::iterator it
= pending_channels_
.begin();
406 it
!= pending_channels_
.end(); ++it
) {
407 if (it
->name
== name
) {
408 pending_channels_
.erase(it
);
414 void ChannelMultiplexer::OnBaseChannelReady(
415 scoped_ptr
<net::StreamSocket
> socket
) {
416 base_channel_factory_
= NULL
;
417 base_channel_
= socket
.Pass();
419 if (base_channel_
.get()) {
420 // Initialize reader and writer.
421 reader_
.Init(base_channel_
.get(),
422 base::Bind(&ChannelMultiplexer::OnIncomingPacket
,
423 base::Unretained(this)));
424 writer_
.Init(base_channel_
.get(),
425 base::Bind(&ChannelMultiplexer::OnWriteFailed
,
426 base::Unretained(this)));
429 DoCreatePendingChannels();
432 void ChannelMultiplexer::DoCreatePendingChannels() {
433 if (pending_channels_
.empty())
436 // Every time this function is called it connects a single channel and posts a
437 // separate task to connect other channels. This is necessary because the
438 // callback may destroy the multiplexer or somehow else modify
439 // |pending_channels_| list (e.g. call CancelChannelCreation()).
440 base::ThreadTaskRunnerHandle::Get()->PostTask(
441 FROM_HERE
, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels
,
442 weak_factory_
.GetWeakPtr()));
444 PendingChannel c
= pending_channels_
.front();
445 pending_channels_
.erase(pending_channels_
.begin());
446 scoped_ptr
<net::StreamSocket
> socket
;
447 if (base_channel_
.get())
448 socket
= GetOrCreateChannel(c
.name
)->CreateSocket();
449 c
.callback
.Run(socket
.Pass());
452 ChannelMultiplexer::MuxChannel
* ChannelMultiplexer::GetOrCreateChannel(
453 const std::string
& name
) {
454 // Check if we already have a channel with the requested name.
455 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
456 if (it
!= channels_
.end())
459 // Create a new channel if we haven't found existing one.
460 MuxChannel
* channel
= new MuxChannel(this, name
, next_channel_id_
);
462 channels_
[channel
->name()] = channel
;
467 void ChannelMultiplexer::OnWriteFailed(int error
) {
468 for (std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.begin();
469 it
!= channels_
.end(); ++it
) {
470 base::ThreadTaskRunnerHandle::Get()->PostTask(
471 FROM_HERE
, base::Bind(&ChannelMultiplexer::NotifyWriteFailed
,
472 weak_factory_
.GetWeakPtr(), it
->second
->name()));
476 void ChannelMultiplexer::NotifyWriteFailed(const std::string
& name
) {
477 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
478 if (it
!= channels_
.end()) {
479 it
->second
->OnWriteFailed();
483 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
484 const base::Closure
& done_task
) {
485 if (!packet
->has_channel_id()) {
486 LOG(ERROR
) << "Received packet without channel_id.";
491 int receive_id
= packet
->channel_id();
492 MuxChannel
* channel
= NULL
;
493 std::map
<int, MuxChannel
*>::iterator it
=
494 channels_by_receive_id_
.find(receive_id
);
495 if (it
!= channels_by_receive_id_
.end()) {
496 channel
= it
->second
;
498 // This is a new |channel_id| we haven't seen before. Look it up by name.
499 if (!packet
->has_channel_name()) {
500 LOG(ERROR
) << "Received packet with unknown channel_id and "
501 "without channel_name.";
505 channel
= GetOrCreateChannel(packet
->channel_name());
506 channel
->set_receive_id(receive_id
);
507 channels_by_receive_id_
[receive_id
] = channel
;
510 channel
->OnIncomingPacket(packet
.Pass(), done_task
);
513 bool ChannelMultiplexer::DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
514 const base::Closure
& done_task
) {
515 return writer_
.Write(SerializeAndFrameMessage(*packet
), done_task
);
518 } // namespace protocol
519 } // namespace remoting