1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/stl_util.h"
14 #include "base/thread_task_runner_handle.h"
15 #include "net/base/net_errors.h"
16 #include "net/socket/stream_socket.h"
17 #include "remoting/protocol/message_serialization.h"
// Sentinel receive id: MuxChannel::receive_id_ is initialized to this value
// and keeps it until ChannelMultiplexer::OnIncomingPacket() sees a packet
// with a not-yet-mapped channel_id and calls set_receive_id() on the channel.
23 const int kChannelIdUnknown
= -1;
// Maximum number of payload bytes MuxSocket::Write() copies from the caller's
// buffer into a single MultiplexPacket. Writes longer than this are partial:
// only the first kMaxPacketSize bytes are sent, and that truncated size is
// what gets recorded in write_result_ for a pending write.
24 const int kMaxPacketSize
= 1024;
28 PendingPacket(scoped_ptr
<MultiplexPacket
> packet
,
29 const base::Closure
& done_task
)
30 : packet(packet
.Pass()),
38 bool is_empty() { return pos
>= packet
->data().size(); }
40 int Read(char* buffer
, size_t size
) {
41 size
= std::min(size
, packet
->data().size() - pos
);
42 memcpy(buffer
, packet
->data().data() + pos
, size
);
48 scoped_ptr
<MultiplexPacket
> packet
;
49 base::Closure done_task
;
52 DISALLOW_COPY_AND_ASSIGN(PendingPacket
);
// Channel name "mux".
// NOTE(review): no use of this constant is visible in this excerpt; it is
// presumably the conventional name for the underlying base channel that
// callers pass as |base_channel_name| to the ChannelMultiplexer constructor.
// Confirm against the callers.
57 const char ChannelMultiplexer::kMuxChannelName
[] = "mux";
59 struct ChannelMultiplexer::PendingChannel
{
60 PendingChannel(const std::string
& name
,
61 const ChannelCreatedCallback
& callback
)
62 : name(name
), callback(callback
) {
65 ChannelCreatedCallback callback
;
68 class ChannelMultiplexer::MuxChannel
{
70 MuxChannel(ChannelMultiplexer
* multiplexer
, const std::string
& name
,
74 const std::string
& name() { return name_
; }
75 int receive_id() { return receive_id_
; }
76 void set_receive_id(int id
) { receive_id_
= id
; }
78 // Called by ChannelMultiplexer.
79 scoped_ptr
<net::StreamSocket
> CreateSocket();
80 void OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
81 const base::Closure
& done_task
);
84 // Called by MuxSocket.
85 void OnSocketDestroyed();
86 bool DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
87 const base::Closure
& done_task
);
88 int DoRead(net::IOBuffer
* buffer
, int buffer_len
);
91 ChannelMultiplexer
* multiplexer_
;
97 std::list
<PendingPacket
*> pending_packets_
;
99 DISALLOW_COPY_AND_ASSIGN(MuxChannel
);
102 class ChannelMultiplexer::MuxSocket
: public net::StreamSocket
,
103 public base::NonThreadSafe
,
104 public base::SupportsWeakPtr
<MuxSocket
> {
106 MuxSocket(MuxChannel
* channel
);
107 ~MuxSocket() override
;
109 void OnWriteComplete();
110 void OnWriteFailed();
111 void OnPacketReceived();
113 // net::StreamSocket interface.
114 int Read(net::IOBuffer
* buffer
,
116 const net::CompletionCallback
& callback
) override
;
117 int Write(net::IOBuffer
* buffer
,
119 const net::CompletionCallback
& callback
) override
;
121 int SetReceiveBufferSize(int32 size
) override
{
123 return net::ERR_NOT_IMPLEMENTED
;
125 int SetSendBufferSize(int32 size
) override
{
127 return net::ERR_NOT_IMPLEMENTED
;
130 int Connect(const net::CompletionCallback
& callback
) override
{
132 return net::ERR_NOT_IMPLEMENTED
;
134 void Disconnect() override
{ NOTIMPLEMENTED(); }
135 bool IsConnected() const override
{
139 bool IsConnectedAndIdle() const override
{
143 int GetPeerAddress(net::IPEndPoint
* address
) const override
{
145 return net::ERR_NOT_IMPLEMENTED
;
147 int GetLocalAddress(net::IPEndPoint
* address
) const override
{
149 return net::ERR_NOT_IMPLEMENTED
;
151 const net::BoundNetLog
& NetLog() const override
{
155 void SetSubresourceSpeculation() override
{ NOTIMPLEMENTED(); }
156 void SetOmniboxSpeculation() override
{ NOTIMPLEMENTED(); }
157 bool WasEverUsed() const override
{ return true; }
158 bool UsingTCPFastOpen() const override
{ return false; }
159 bool WasNpnNegotiated() const override
{ return false; }
160 net::NextProto
GetNegotiatedProtocol() const override
{
161 return net::kProtoUnknown
;
163 bool GetSSLInfo(net::SSLInfo
* ssl_info
) override
{
169 MuxChannel
* channel_
;
171 net::CompletionCallback read_callback_
;
172 scoped_refptr
<net::IOBuffer
> read_buffer_
;
173 int read_buffer_size_
;
177 net::CompletionCallback write_callback_
;
179 net::BoundNetLog net_log_
;
181 DISALLOW_COPY_AND_ASSIGN(MuxSocket
);
185 ChannelMultiplexer::MuxChannel::MuxChannel(
186 ChannelMultiplexer
* multiplexer
,
187 const std::string
& name
,
189 : multiplexer_(multiplexer
),
193 receive_id_(kChannelIdUnknown
),
197 ChannelMultiplexer::MuxChannel::~MuxChannel() {
198 // Socket must be destroyed before the channel.
200 STLDeleteElements(&pending_packets_
);
203 scoped_ptr
<net::StreamSocket
> ChannelMultiplexer::MuxChannel::CreateSocket() {
204 DCHECK(!socket_
); // Can't create more than one socket per channel.
205 scoped_ptr
<MuxSocket
> result(new MuxSocket(this));
206 socket_
= result
.get();
207 return result
.Pass();
210 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
211 scoped_ptr
<MultiplexPacket
> packet
,
212 const base::Closure
& done_task
) {
213 DCHECK_EQ(packet
->channel_id(), receive_id_
);
214 if (packet
->data().size() > 0) {
215 pending_packets_
.push_back(new PendingPacket(packet
.Pass(), done_task
));
217 // Notify the socket that we have more data.
218 socket_
->OnPacketReceived();
223 void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
225 socket_
->OnWriteFailed();
228 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
233 bool ChannelMultiplexer::MuxChannel::DoWrite(
234 scoped_ptr
<MultiplexPacket
> packet
,
235 const base::Closure
& done_task
) {
236 packet
->set_channel_id(send_id_
);
238 packet
->set_channel_name(name_
);
241 return multiplexer_
->DoWrite(packet
.Pass(), done_task
);
244 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer
* buffer
,
247 while (buffer_len
> 0 && !pending_packets_
.empty()) {
248 DCHECK(!pending_packets_
.front()->is_empty());
249 int result
= pending_packets_
.front()->Read(
250 buffer
->data() + pos
, buffer_len
);
251 DCHECK_LE(result
, buffer_len
);
254 if (pending_packets_
.front()->is_empty()) {
255 delete pending_packets_
.front();
256 pending_packets_
.erase(pending_packets_
.begin());
262 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel
* channel
)
264 read_buffer_size_(0),
265 write_pending_(false),
269 ChannelMultiplexer::MuxSocket::~MuxSocket() {
270 channel_
->OnSocketDestroyed();
273 int ChannelMultiplexer::MuxSocket::Read(
274 net::IOBuffer
* buffer
, int buffer_len
,
275 const net::CompletionCallback
& callback
) {
276 DCHECK(CalledOnValidThread());
277 DCHECK(read_callback_
.is_null());
279 int result
= channel_
->DoRead(buffer
, buffer_len
);
281 read_buffer_
= buffer
;
282 read_buffer_size_
= buffer_len
;
283 read_callback_
= callback
;
284 return net::ERR_IO_PENDING
;
289 int ChannelMultiplexer::MuxSocket::Write(
290 net::IOBuffer
* buffer
, int buffer_len
,
291 const net::CompletionCallback
& callback
) {
292 DCHECK(CalledOnValidThread());
294 scoped_ptr
<MultiplexPacket
> packet(new MultiplexPacket());
295 size_t size
= std::min(kMaxPacketSize
, buffer_len
);
296 packet
->mutable_data()->assign(buffer
->data(), size
);
298 write_pending_
= true;
299 bool result
= channel_
->DoWrite(packet
.Pass(), base::Bind(
300 &ChannelMultiplexer::MuxSocket::OnWriteComplete
, AsWeakPtr()));
303 // Cannot complete the write, e.g. if the connection has been terminated.
304 return net::ERR_FAILED
;
307 // OnWriteComplete() might be called above synchronously.
308 if (write_pending_
) {
309 DCHECK(write_callback_
.is_null());
310 write_callback_
= callback
;
311 write_result_
= size
;
312 return net::ERR_IO_PENDING
;
318 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
319 write_pending_
= false;
320 if (!write_callback_
.is_null()) {
321 net::CompletionCallback cb
;
322 std::swap(cb
, write_callback_
);
323 cb
.Run(write_result_
);
327 void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
328 if (!write_callback_
.is_null()) {
329 net::CompletionCallback cb
;
330 std::swap(cb
, write_callback_
);
331 cb
.Run(net::ERR_FAILED
);
335 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
336 if (!read_callback_
.is_null()) {
337 int result
= channel_
->DoRead(read_buffer_
.get(), read_buffer_size_
);
339 DCHECK_GT(result
, 0);
340 net::CompletionCallback cb
;
341 std::swap(cb
, read_callback_
);
346 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory
* factory
,
347 const std::string
& base_channel_name
)
348 : base_channel_factory_(factory
),
349 base_channel_name_(base_channel_name
),
351 weak_factory_(this) {
354 ChannelMultiplexer::~ChannelMultiplexer() {
355 DCHECK(pending_channels_
.empty());
356 STLDeleteValues(&channels_
);
358 // Cancel creation of the base channel if it hasn't finished.
359 if (base_channel_factory_
)
360 base_channel_factory_
->CancelChannelCreation(base_channel_name_
);
363 void ChannelMultiplexer::CreateChannel(const std::string
& name
,
364 const ChannelCreatedCallback
& callback
) {
365 if (base_channel_
.get()) {
366 // Already have |base_channel_|. Create new multiplexed channel
368 callback
.Run(GetOrCreateChannel(name
)->CreateSocket());
369 } else if (!base_channel_
.get() && !base_channel_factory_
) {
370 // Fail synchronously if we failed to create |base_channel_|.
371 callback
.Run(nullptr);
373 // Still waiting for the |base_channel_|.
374 pending_channels_
.push_back(PendingChannel(name
, callback
));
376 // If this is the first multiplexed channel then create the base channel.
377 if (pending_channels_
.size() == 1U) {
378 base_channel_factory_
->CreateChannel(
380 base::Bind(&ChannelMultiplexer::OnBaseChannelReady
,
381 base::Unretained(this)));
386 void ChannelMultiplexer::CancelChannelCreation(const std::string
& name
) {
387 for (std::list
<PendingChannel
>::iterator it
= pending_channels_
.begin();
388 it
!= pending_channels_
.end(); ++it
) {
389 if (it
->name
== name
) {
390 pending_channels_
.erase(it
);
396 void ChannelMultiplexer::OnBaseChannelReady(
397 scoped_ptr
<net::StreamSocket
> socket
) {
398 base_channel_factory_
= NULL
;
399 base_channel_
= socket
.Pass();
401 if (base_channel_
.get()) {
402 // Initialize reader and writer.
403 reader_
.Init(base_channel_
.get(),
404 base::Bind(&ChannelMultiplexer::OnIncomingPacket
,
405 base::Unretained(this)));
406 writer_
.Init(base_channel_
.get(),
407 base::Bind(&ChannelMultiplexer::OnWriteFailed
,
408 base::Unretained(this)));
411 DoCreatePendingChannels();
414 void ChannelMultiplexer::DoCreatePendingChannels() {
415 if (pending_channels_
.empty())
418 // Every time this function is called it connects a single channel and posts a
419 // separate task to connect other channels. This is necessary because the
420 // callback may destroy the multiplexer or somehow else modify
421 // |pending_channels_| list (e.g. call CancelChannelCreation()).
422 base::ThreadTaskRunnerHandle::Get()->PostTask(
423 FROM_HERE
, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels
,
424 weak_factory_
.GetWeakPtr()));
426 PendingChannel c
= pending_channels_
.front();
427 pending_channels_
.erase(pending_channels_
.begin());
428 scoped_ptr
<net::StreamSocket
> socket
;
429 if (base_channel_
.get())
430 socket
= GetOrCreateChannel(c
.name
)->CreateSocket();
431 c
.callback
.Run(socket
.Pass());
434 ChannelMultiplexer::MuxChannel
* ChannelMultiplexer::GetOrCreateChannel(
435 const std::string
& name
) {
436 // Check if we already have a channel with the requested name.
437 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
438 if (it
!= channels_
.end())
441 // Create a new channel if we haven't found existing one.
442 MuxChannel
* channel
= new MuxChannel(this, name
, next_channel_id_
);
444 channels_
[channel
->name()] = channel
;
449 void ChannelMultiplexer::OnWriteFailed(int error
) {
450 for (std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.begin();
451 it
!= channels_
.end(); ++it
) {
452 base::ThreadTaskRunnerHandle::Get()->PostTask(
453 FROM_HERE
, base::Bind(&ChannelMultiplexer::NotifyWriteFailed
,
454 weak_factory_
.GetWeakPtr(), it
->second
->name()));
458 void ChannelMultiplexer::NotifyWriteFailed(const std::string
& name
) {
459 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
460 if (it
!= channels_
.end()) {
461 it
->second
->OnWriteFailed();
465 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
466 const base::Closure
& done_task
) {
467 DCHECK(packet
->has_channel_id());
468 if (!packet
->has_channel_id()) {
469 LOG(ERROR
) << "Received packet without channel_id.";
474 int receive_id
= packet
->channel_id();
475 MuxChannel
* channel
= NULL
;
476 std::map
<int, MuxChannel
*>::iterator it
=
477 channels_by_receive_id_
.find(receive_id
);
478 if (it
!= channels_by_receive_id_
.end()) {
479 channel
= it
->second
;
481 // This is a new |channel_id| we haven't seen before. Look it up by name.
482 if (!packet
->has_channel_name()) {
483 LOG(ERROR
) << "Received packet with unknown channel_id and "
484 "without channel_name.";
488 channel
= GetOrCreateChannel(packet
->channel_name());
489 channel
->set_receive_id(receive_id
);
490 channels_by_receive_id_
[receive_id
] = channel
;
493 channel
->OnIncomingPacket(packet
.Pass(), done_task
);
496 bool ChannelMultiplexer::DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
497 const base::Closure
& done_task
) {
498 return writer_
.Write(SerializeAndFrameMessage(*packet
), done_task
);
501 } // namespace protocol
502 } // namespace remoting