1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
10 #include "base/callback.h"
11 #include "base/callback_helpers.h"
12 #include "base/location.h"
13 #include "base/single_thread_task_runner.h"
14 #include "base/stl_util.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "net/base/net_errors.h"
17 #include "net/socket/stream_socket.h"
18 #include "remoting/protocol/message_serialization.h"
// Receive id a MuxChannel starts out with, before any packet carrying the
// peer's id for that channel has been seen.
const int kChannelIdUnknown = -1;

// Upper bound, in bytes, on the payload a single MuxSocket::Write() call copies
// into one MultiplexPacket. Larger writes are truncated to this size, so the
// write reports at most this many bytes as written.
const int kMaxPacketSize = 1024;
29 PendingPacket(scoped_ptr
<MultiplexPacket
> packet
,
30 const base::Closure
& done_task
)
31 : packet(packet
.Pass()),
39 bool is_empty() { return pos
>= packet
->data().size(); }
41 int Read(char* buffer
, size_t size
) {
42 size
= std::min(size
, packet
->data().size() - pos
);
43 memcpy(buffer
, packet
->data().data() + pos
, size
);
49 scoped_ptr
<MultiplexPacket
> packet
;
50 base::Closure done_task
;
53 DISALLOW_COPY_AND_ASSIGN(PendingPacket
);
58 const char ChannelMultiplexer::kMuxChannelName
[] = "mux";
60 struct ChannelMultiplexer::PendingChannel
{
61 PendingChannel(const std::string
& name
,
62 const ChannelCreatedCallback
& callback
)
63 : name(name
), callback(callback
) {
66 ChannelCreatedCallback callback
;
69 class ChannelMultiplexer::MuxChannel
{
71 MuxChannel(ChannelMultiplexer
* multiplexer
, const std::string
& name
,
75 const std::string
& name() { return name_
; }
76 int receive_id() { return receive_id_
; }
77 void set_receive_id(int id
) { receive_id_
= id
; }
79 // Called by ChannelMultiplexer.
80 scoped_ptr
<net::StreamSocket
> CreateSocket();
81 void OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
82 const base::Closure
& done_task
);
83 void OnBaseChannelError(int error
);
85 // Called by MuxSocket.
86 void OnSocketDestroyed();
87 bool DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
88 const base::Closure
& done_task
);
89 int DoRead(net::IOBuffer
* buffer
, int buffer_len
);
92 ChannelMultiplexer
* multiplexer_
;
98 std::list
<PendingPacket
*> pending_packets_
;
100 DISALLOW_COPY_AND_ASSIGN(MuxChannel
);
103 class ChannelMultiplexer::MuxSocket
: public net::StreamSocket
,
104 public base::NonThreadSafe
,
105 public base::SupportsWeakPtr
<MuxSocket
> {
107 MuxSocket(MuxChannel
* channel
);
108 ~MuxSocket() override
;
110 void OnWriteComplete();
111 void OnBaseChannelError(int error
);
112 void OnPacketReceived();
114 // net::StreamSocket interface.
115 int Read(net::IOBuffer
* buffer
,
117 const net::CompletionCallback
& callback
) override
;
118 int Write(net::IOBuffer
* buffer
,
120 const net::CompletionCallback
& callback
) override
;
122 int SetReceiveBufferSize(int32 size
) override
{
124 return net::ERR_NOT_IMPLEMENTED
;
126 int SetSendBufferSize(int32 size
) override
{
128 return net::ERR_NOT_IMPLEMENTED
;
131 int Connect(const net::CompletionCallback
& callback
) override
{
133 return net::ERR_NOT_IMPLEMENTED
;
135 void Disconnect() override
{ NOTIMPLEMENTED(); }
136 bool IsConnected() const override
{
140 bool IsConnectedAndIdle() const override
{
144 int GetPeerAddress(net::IPEndPoint
* address
) const override
{
146 return net::ERR_NOT_IMPLEMENTED
;
148 int GetLocalAddress(net::IPEndPoint
* address
) const override
{
150 return net::ERR_NOT_IMPLEMENTED
;
152 const net::BoundNetLog
& NetLog() const override
{
156 void SetSubresourceSpeculation() override
{ NOTIMPLEMENTED(); }
157 void SetOmniboxSpeculation() override
{ NOTIMPLEMENTED(); }
158 bool WasEverUsed() const override
{ return true; }
159 bool UsingTCPFastOpen() const override
{ return false; }
160 bool WasNpnNegotiated() const override
{ return false; }
161 net::NextProto
GetNegotiatedProtocol() const override
{
162 return net::kProtoUnknown
;
164 bool GetSSLInfo(net::SSLInfo
* ssl_info
) override
{
168 void GetConnectionAttempts(net::ConnectionAttempts
* out
) const override
{
171 void ClearConnectionAttempts() override
{}
172 void AddConnectionAttempts(const net::ConnectionAttempts
& attempts
) override
{
176 MuxChannel
* channel_
;
178 int base_channel_error_
= net::OK
;
180 net::CompletionCallback read_callback_
;
181 scoped_refptr
<net::IOBuffer
> read_buffer_
;
182 int read_buffer_size_
;
186 net::CompletionCallback write_callback_
;
188 net::BoundNetLog net_log_
;
190 DISALLOW_COPY_AND_ASSIGN(MuxSocket
);
194 ChannelMultiplexer::MuxChannel::MuxChannel(
195 ChannelMultiplexer
* multiplexer
,
196 const std::string
& name
,
198 : multiplexer_(multiplexer
),
202 receive_id_(kChannelIdUnknown
),
206 ChannelMultiplexer::MuxChannel::~MuxChannel() {
207 // Socket must be destroyed before the channel.
209 STLDeleteElements(&pending_packets_
);
212 scoped_ptr
<net::StreamSocket
> ChannelMultiplexer::MuxChannel::CreateSocket() {
213 DCHECK(!socket_
); // Can't create more than one socket per channel.
214 scoped_ptr
<MuxSocket
> result(new MuxSocket(this));
215 socket_
= result
.get();
216 return result
.Pass();
219 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
220 scoped_ptr
<MultiplexPacket
> packet
,
221 const base::Closure
& done_task
) {
222 DCHECK_EQ(packet
->channel_id(), receive_id_
);
223 if (packet
->data().size() > 0) {
224 pending_packets_
.push_back(new PendingPacket(packet
.Pass(), done_task
));
226 // Notify the socket that we have more data.
227 socket_
->OnPacketReceived();
232 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error
) {
234 socket_
->OnBaseChannelError(error
);
237 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
242 bool ChannelMultiplexer::MuxChannel::DoWrite(
243 scoped_ptr
<MultiplexPacket
> packet
,
244 const base::Closure
& done_task
) {
245 packet
->set_channel_id(send_id_
);
247 packet
->set_channel_name(name_
);
250 return multiplexer_
->DoWrite(packet
.Pass(), done_task
);
253 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer
* buffer
,
256 while (buffer_len
> 0 && !pending_packets_
.empty()) {
257 DCHECK(!pending_packets_
.front()->is_empty());
258 int result
= pending_packets_
.front()->Read(
259 buffer
->data() + pos
, buffer_len
);
260 DCHECK_LE(result
, buffer_len
);
263 if (pending_packets_
.front()->is_empty()) {
264 delete pending_packets_
.front();
265 pending_packets_
.erase(pending_packets_
.begin());
271 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel
* channel
)
273 read_buffer_size_(0),
274 write_pending_(false),
278 ChannelMultiplexer::MuxSocket::~MuxSocket() {
279 channel_
->OnSocketDestroyed();
282 int ChannelMultiplexer::MuxSocket::Read(
283 net::IOBuffer
* buffer
, int buffer_len
,
284 const net::CompletionCallback
& callback
) {
285 DCHECK(CalledOnValidThread());
286 DCHECK(read_callback_
.is_null());
288 if (base_channel_error_
!= net::OK
)
289 return base_channel_error_
;
291 int result
= channel_
->DoRead(buffer
, buffer_len
);
293 read_buffer_
= buffer
;
294 read_buffer_size_
= buffer_len
;
295 read_callback_
= callback
;
296 return net::ERR_IO_PENDING
;
301 int ChannelMultiplexer::MuxSocket::Write(
302 net::IOBuffer
* buffer
, int buffer_len
,
303 const net::CompletionCallback
& callback
) {
304 DCHECK(CalledOnValidThread());
305 DCHECK(write_callback_
.is_null());
307 if (base_channel_error_
!= net::OK
)
308 return base_channel_error_
;
310 scoped_ptr
<MultiplexPacket
> packet(new MultiplexPacket());
311 size_t size
= std::min(kMaxPacketSize
, buffer_len
);
312 packet
->mutable_data()->assign(buffer
->data(), size
);
314 write_pending_
= true;
315 bool result
= channel_
->DoWrite(packet
.Pass(), base::Bind(
316 &ChannelMultiplexer::MuxSocket::OnWriteComplete
, AsWeakPtr()));
319 // Cannot complete the write, e.g. if the connection has been terminated.
320 return net::ERR_FAILED
;
323 // OnWriteComplete() might be called above synchronously.
324 if (write_pending_
) {
325 DCHECK(write_callback_
.is_null());
326 write_callback_
= callback
;
327 write_result_
= size
;
328 return net::ERR_IO_PENDING
;
334 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
335 write_pending_
= false;
336 if (!write_callback_
.is_null())
337 base::ResetAndReturn(&write_callback_
).Run(write_result_
);
341 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error
) {
342 base_channel_error_
= error
;
344 // Here only one of the read and write callbacks is called if both of them are
345 // pending. Ideally both of them should be called in that case, but that would
346 // require the second one to be called asynchronously which would complicate
347 // this code. Channels handle read and write errors the same way (see
348 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
349 // callbacks is enough.
351 if (!read_callback_
.is_null()) {
352 base::ResetAndReturn(&read_callback_
).Run(error
);
356 if (!write_callback_
.is_null())
357 base::ResetAndReturn(&write_callback_
).Run(error
);
360 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
361 if (!read_callback_
.is_null()) {
362 int result
= channel_
->DoRead(read_buffer_
.get(), read_buffer_size_
);
363 read_buffer_
= nullptr;
364 DCHECK_GT(result
, 0);
365 base::ResetAndReturn(&read_callback_
).Run(result
);
369 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory
* factory
,
370 const std::string
& base_channel_name
)
371 : base_channel_factory_(factory
),
372 base_channel_name_(base_channel_name
),
374 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket
,
375 base::Unretained(this)),
377 weak_factory_(this) {
380 ChannelMultiplexer::~ChannelMultiplexer() {
381 DCHECK(pending_channels_
.empty());
382 STLDeleteValues(&channels_
);
384 // Cancel creation of the base channel if it hasn't finished.
385 if (base_channel_factory_
)
386 base_channel_factory_
->CancelChannelCreation(base_channel_name_
);
389 void ChannelMultiplexer::CreateChannel(const std::string
& name
,
390 const ChannelCreatedCallback
& callback
) {
391 if (base_channel_
.get()) {
392 // Already have |base_channel_|. Create new multiplexed channel
394 callback
.Run(GetOrCreateChannel(name
)->CreateSocket());
395 } else if (!base_channel_
.get() && !base_channel_factory_
) {
396 // Fail synchronously if we failed to create |base_channel_|.
397 callback
.Run(nullptr);
399 // Still waiting for the |base_channel_|.
400 pending_channels_
.push_back(PendingChannel(name
, callback
));
402 // If this is the first multiplexed channel then create the base channel.
403 if (pending_channels_
.size() == 1U) {
404 base_channel_factory_
->CreateChannel(
406 base::Bind(&ChannelMultiplexer::OnBaseChannelReady
,
407 base::Unretained(this)));
412 void ChannelMultiplexer::CancelChannelCreation(const std::string
& name
) {
413 for (std::list
<PendingChannel
>::iterator it
= pending_channels_
.begin();
414 it
!= pending_channels_
.end(); ++it
) {
415 if (it
->name
== name
) {
416 pending_channels_
.erase(it
);
422 void ChannelMultiplexer::OnBaseChannelReady(
423 scoped_ptr
<net::StreamSocket
> socket
) {
424 base_channel_factory_
= nullptr;
425 base_channel_
= socket
.Pass();
427 if (base_channel_
.get()) {
428 // Initialize reader and writer.
429 reader_
.StartReading(base_channel_
.get(),
430 base::Bind(&ChannelMultiplexer::OnBaseChannelError
,
431 base::Unretained(this)));
432 writer_
.Init(base_channel_
.get(),
433 base::Bind(&ChannelMultiplexer::OnBaseChannelError
,
434 base::Unretained(this)));
437 DoCreatePendingChannels();
440 void ChannelMultiplexer::DoCreatePendingChannels() {
441 if (pending_channels_
.empty())
444 // Every time this function is called it connects a single channel and posts a
445 // separate task to connect other channels. This is necessary because the
446 // callback may destroy the multiplexer or somehow else modify
447 // |pending_channels_| list (e.g. call CancelChannelCreation()).
448 base::ThreadTaskRunnerHandle::Get()->PostTask(
449 FROM_HERE
, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels
,
450 weak_factory_
.GetWeakPtr()));
452 PendingChannel c
= pending_channels_
.front();
453 pending_channels_
.erase(pending_channels_
.begin());
454 scoped_ptr
<net::StreamSocket
> socket
;
455 if (base_channel_
.get())
456 socket
= GetOrCreateChannel(c
.name
)->CreateSocket();
457 c
.callback
.Run(socket
.Pass());
460 ChannelMultiplexer::MuxChannel
* ChannelMultiplexer::GetOrCreateChannel(
461 const std::string
& name
) {
462 // Check if we already have a channel with the requested name.
463 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
464 if (it
!= channels_
.end())
467 // Create a new channel if we haven't found existing one.
468 MuxChannel
* channel
= new MuxChannel(this, name
, next_channel_id_
);
470 channels_
[channel
->name()] = channel
;
475 void ChannelMultiplexer::OnBaseChannelError(int error
) {
476 for (std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.begin();
477 it
!= channels_
.end(); ++it
) {
478 base::ThreadTaskRunnerHandle::Get()->PostTask(
480 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError
,
481 weak_factory_
.GetWeakPtr(), it
->second
->name(), error
));
485 void ChannelMultiplexer::NotifyBaseChannelError(const std::string
& name
,
487 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
488 if (it
!= channels_
.end())
489 it
->second
->OnBaseChannelError(error
);
492 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
493 const base::Closure
& done_task
) {
494 DCHECK(packet
->has_channel_id());
495 if (!packet
->has_channel_id()) {
496 LOG(ERROR
) << "Received packet without channel_id.";
501 int receive_id
= packet
->channel_id();
502 MuxChannel
* channel
= nullptr;
503 std::map
<int, MuxChannel
*>::iterator it
=
504 channels_by_receive_id_
.find(receive_id
);
505 if (it
!= channels_by_receive_id_
.end()) {
506 channel
= it
->second
;
508 // This is a new |channel_id| we haven't seen before. Look it up by name.
509 if (!packet
->has_channel_name()) {
510 LOG(ERROR
) << "Received packet with unknown channel_id and "
511 "without channel_name.";
515 channel
= GetOrCreateChannel(packet
->channel_name());
516 channel
->set_receive_id(receive_id
);
517 channels_by_receive_id_
[receive_id
] = channel
;
520 channel
->OnIncomingPacket(packet
.Pass(), done_task
);
523 bool ChannelMultiplexer::DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
524 const base::Closure
& done_task
) {
525 return writer_
.Write(SerializeAndFrameMessage(*packet
), done_task
);
528 } // namespace protocol
529 } // namespace remoting