1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
10 #include "base/callback.h"
11 #include "base/callback_helpers.h"
12 #include "base/location.h"
13 #include "base/single_thread_task_runner.h"
14 #include "base/stl_util.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "net/base/net_errors.h"
17 #include "remoting/protocol/message_serialization.h"
18 #include "remoting/protocol/p2p_stream_socket.h"
// Receive id a MuxChannel holds until the first packet from the peer for
// that channel binds one (see ChannelMultiplexer::OnIncomingPacket()).
24 const int kChannelIdUnknown
= -1;
// Maximum number of payload bytes MuxSocket::Write() copies into a single
// MultiplexPacket; one Write() call sends at most this many bytes.
25 const int kMaxPacketSize
= 1024;
// Takes ownership of |packet| and stores |done_task| alongside it.
// NOTE(review): the point at which |done_task| runs is on lines not shown in
// this listing; presumably it runs once the packet has been fully consumed.
// Confirm.
29 PendingPacket(scoped_ptr
<MultiplexPacket
> packet
,
30 const base::Closure
& done_task
)
31 : packet(packet
.Pass()),
// Returns true once the read offset |pos| has reached the end of the packet
// payload, i.e. nothing is left to read.
39 bool is_empty() { return pos
>= packet
->data().size(); }
// Copies up to |size| unread payload bytes, starting at offset |pos|, into
// |buffer|. |size| is first clamped to the number of bytes remaining.
// NOTE(review): advancing |pos| and the return statement are on lines not
// shown here; presumably the number of bytes copied is returned. Confirm.
41 int Read(char* buffer
, size_t size
) {
42 size
= std::min(size
, packet
->data().size() - pos
);
43 memcpy(buffer
, packet
->data().data() + pos
, size
);
// The packet being drained; owned by this object.
49 scoped_ptr
<MultiplexPacket
> packet
;
// Closure handed to the constructor together with |packet|.
50 base::Closure done_task
;
53 DISALLOW_COPY_AND_ASSIGN(PendingPacket
);
// Channel name constant exposed as ChannelMultiplexer::kMuxChannelName.
// NOTE(review): its use is outside this listing; presumably it is the name
// callers give the base channel that carries multiplexed traffic. Confirm.
58 const char ChannelMultiplexer::kMuxChannelName
[] = "mux";
// A CreateChannel() request that arrived before the base channel was ready.
// CreateChannel() queues it in |pending_channels_|.
// DoCreatePendingChannels() later serves it, or fails it with a null socket.
60 struct ChannelMultiplexer::PendingChannel
{
61 PendingChannel(const std::string
& name
,
62 const ChannelCreatedCallback
& callback
)
63 : name(name
), callback(callback
) {
// Receives the created socket, or a null socket if the base channel could
// not be created.
66 ChannelCreatedCallback callback
;
// One logical channel carried over the shared base channel. It owns at most
// one MuxSocket's incoming data queue. It stamps outgoing packets with its
// ids and hands them to the ChannelMultiplexer.
69 class ChannelMultiplexer::MuxChannel
{
71 MuxChannel(ChannelMultiplexer
* multiplexer
, const std::string
& name
,
// Channel name; also the key under which the multiplexer stores it in
// |channels_|.
75 const std::string
& name() { return name_
; }
// Id the peer uses for this channel. It is kChannelIdUnknown until the first
// incoming packet for the channel is seen.
76 int receive_id() { return receive_id_
; }
77 void set_receive_id(int id
) { receive_id_
= id
; }
79 // Called by ChannelMultiplexer.
80 scoped_ptr
<P2PStreamSocket
> CreateSocket();
81 void OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
82 const base::Closure
& done_task
);
83 void OnBaseChannelError(int error
);
85 // Called by MuxSocket.
86 void OnSocketDestroyed();
87 void DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
88 const base::Closure
& done_task
);
89 int DoRead(const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
);
// Back-pointer to the ChannelMultiplexer, which owns this channel and
// deletes it in its destructor.
92 ChannelMultiplexer
* multiplexer_
;
// Received packets that have not been fully read yet. These are owned
// pointers: they are deleted by DoRead() once drained, or by the destructor.
98 std::list
<PendingPacket
*> pending_packets_
;
100 DISALLOW_COPY_AND_ASSIGN(MuxChannel
);
// P2PStreamSocket handed out for a MuxChannel. Reads are served from the
// channel's queued packets. Each Write() sends at most kMaxPacketSize bytes as
// one MultiplexPacket. Weak pointers guard the write-completion closure.
103 class ChannelMultiplexer::MuxSocket
: public P2PStreamSocket
,
104 public base::NonThreadSafe
,
105 public base::SupportsWeakPtr
<MuxSocket
> {
107 MuxSocket(MuxChannel
* channel
);
108 ~MuxSocket() override
;
// Notifications from the owning MuxChannel / multiplexer.
110 void OnWriteComplete();
111 void OnBaseChannelError(int error
);
112 void OnPacketReceived();
114 // P2PStreamSocket interface.
115 int Read(const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
,
116 const net::CompletionCallback
& callback
) override
;
117 int Write(const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
,
118 const net::CompletionCallback
& callback
) override
;
// Channel this socket belongs to. The socket must be destroyed before the
// channel.
121 MuxChannel
* channel_
;
// Error reported by the base channel. Once set, Read() and Write() return it
// immediately.
123 int base_channel_error_
= net::OK
;
// State of a Read() that returned ERR_IO_PENDING; consumed by
// OnPacketReceived().
125 net::CompletionCallback read_callback_
;
126 scoped_refptr
<net::IOBuffer
> read_buffer_
;
127 int read_buffer_size_
;
// Callback of a Write() that returned ERR_IO_PENDING. It is run by
// OnWriteComplete() or OnBaseChannelError().
131 net::CompletionCallback write_callback_
;
133 DISALLOW_COPY_AND_ASSIGN(MuxSocket
);
// Creates a channel named |name| belonging to |multiplexer|. The receive id
// starts as kChannelIdUnknown until the peer's first packet for the channel
// arrives.
// NOTE(review): the remaining parameter and initializers are on lines not
// shown here. GetOrCreateChannel() passes next_channel_id_, presumably used as
// |send_id_|. Confirm.
137 ChannelMultiplexer::MuxChannel::MuxChannel(
138 ChannelMultiplexer
* multiplexer
,
139 const std::string
& name
,
141 : multiplexer_(multiplexer
),
145 receive_id_(kChannelIdUnknown
),
// Destroys the channel and frees any queued packets that were never fully
// read.
149 ChannelMultiplexer::MuxChannel::~MuxChannel() {
150 // Socket must be destroyed before the channel.
// |pending_packets_| holds owned raw pointers.
152 STLDeleteElements(&pending_packets_
);
// Creates the single MuxSocket for this channel. Ownership goes to the
// caller; the channel keeps a non-owning pointer in |socket_|.
155 scoped_ptr
<P2PStreamSocket
> ChannelMultiplexer::MuxChannel::CreateSocket() {
156 DCHECK(!socket_
); // Can't create more than one socket per channel.
157 scoped_ptr
<MuxSocket
> result(new MuxSocket(this));
158 socket_
= result
.get();
159 return result
.Pass();
// Accepts a packet routed to this channel by ChannelMultiplexer. A packet
// with a non-empty payload is queued together with |done_task|, and the
// socket is told that more data is available.
// NOTE(review): handling of empty packets, and any null check on |socket_|,
// are on lines not shown in this listing.
162 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
163 scoped_ptr
<MultiplexPacket
> packet
,
164 const base::Closure
& done_task
) {
165 DCHECK_EQ(packet
->channel_id(), receive_id_
);
166 if (packet
->data().size() > 0) {
167 pending_packets_
.push_back(new PendingPacket(packet
.Pass(), done_task
));
169 // Notify the socket that we have more data.
170 socket_
->OnPacketReceived();
// Forwards a base channel error to this channel's socket.
// NOTE(review): a guard for a missing |socket_| may be on a line not shown.
175 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error
) {
177 socket_
->OnBaseChannelError(error
);
// Called from ~MuxSocket() when the channel's socket is destroyed.
// NOTE(review): the body is not shown in this listing; presumably it clears
// |socket_|. Confirm.
180 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
// Sends |packet| for this channel. It tags the packet with |send_id_| and the
// channel name, then passes it and |done_task| to
// ChannelMultiplexer::DoWrite().
// NOTE(review): line 189 is not shown; the channel name may be attached only
// conditionally (e.g. on the first packet). Confirm.
185 void ChannelMultiplexer::MuxChannel::DoWrite(
186 scoped_ptr
<MultiplexPacket
> packet
,
187 const base::Closure
& done_task
) {
188 packet
->set_channel_id(send_id_
);
190 packet
->set_channel_name(name_
);
193 multiplexer_
->DoWrite(packet
.Pass(), done_task
);
// Copies queued payload bytes into |buffer| until the buffer is full or the
// queue is empty. Fully consumed packets are deleted and removed from the
// front of |pending_packets_|.
// NOTE(review): the running offset |pos|, the decrement of |buffer_len| and
// the return statement are on lines not shown. Presumably it returns the
// number of bytes copied (0 when nothing was queued). Confirm.
196 int ChannelMultiplexer::MuxChannel::DoRead(
197 const scoped_refptr
<net::IOBuffer
>& buffer
,
200 while (buffer_len
> 0 && !pending_packets_
.empty()) {
// Queued packets are never empty: OnIncomingPacket() only queues non-empty
// payloads, and drained packets are removed below.
201 DCHECK(!pending_packets_
.front()->is_empty());
202 int result
= pending_packets_
.front()->Read(
203 buffer
->data() + pos
, buffer_len
);
204 DCHECK_LE(result
, buffer_len
);
// Free the front packet once all of its payload has been copied out.
207 if (pending_packets_
.front()->is_empty()) {
208 delete pending_packets_
.front();
209 pending_packets_
.erase(pending_packets_
.begin());
// Creates a socket bound to |channel| with no read buffer and no write in
// flight.
215 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel
* channel
)
217 read_buffer_size_(0),
218 write_pending_(false),
// Notifies |channel_| that its socket has been destroyed.
222 ChannelMultiplexer::MuxSocket::~MuxSocket() {
223 channel_
->OnSocketDestroyed();
// P2PStreamSocket::Read(). Returns the stored base channel error, if any.
// Otherwise it tries to satisfy the read from data already queued on the
// channel. On the pending path it stores the buffer and |callback| and
// returns ERR_IO_PENDING, so OnPacketReceived() can complete the read later.
// NOTE(review): the condition around the pending path (lines 236, 241-243) is
// not shown. Presumably ERR_IO_PENDING is returned only when DoRead()
// returned 0. Confirm.
226 int ChannelMultiplexer::MuxSocket::Read(
227 const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
,
228 const net::CompletionCallback
& callback
) {
229 DCHECK(CalledOnValidThread());
230 DCHECK(read_callback_
.is_null());
232 if (base_channel_error_
!= net::OK
)
233 return base_channel_error_
;
235 int result
= channel_
->DoRead(buffer
, buffer_len
);
237 read_buffer_
= buffer
;
238 read_buffer_size_
= buffer_len
;
239 read_callback_
= callback
;
240 return net::ERR_IO_PENDING
;
// P2PStreamSocket::Write(). Returns the stored base channel error, if any.
// Otherwise it copies at most kMaxPacketSize bytes of |buffer| into a new
// MultiplexPacket and sends it through the channel. If the send completes
// synchronously, OnWriteComplete() has already cleared |write_pending_|.
// Otherwise |callback| is stored and ERR_IO_PENDING is returned.
// NOTE(review): a negative |buffer_len| would pass through std::min() and be
// converted to a huge size_t below. Callers presumably never pass one. Confirm.
245 int ChannelMultiplexer::MuxSocket::Write(
246 const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
,
247 const net::CompletionCallback
& callback
) {
248 DCHECK(CalledOnValidThread());
249 DCHECK(write_callback_
.is_null());
251 if (base_channel_error_
!= net::OK
)
252 return base_channel_error_
;
254 scoped_ptr
<MultiplexPacket
> packet(new MultiplexPacket());
// Only the first kMaxPacketSize bytes are sent by this call.
255 size_t size
= std::min(kMaxPacketSize
, buffer_len
);
256 packet
->mutable_data()->assign(buffer
->data(), size
);
258 write_pending_
= true;
// The weak pointer keeps a late completion from touching a destroyed socket.
259 channel_
->DoWrite(packet
.Pass(), base::Bind(
260 &ChannelMultiplexer::MuxSocket::OnWriteComplete
, AsWeakPtr()));
262 // OnWriteComplete() might be called above synchronously.
263 if (write_pending_
) {
264 DCHECK(write_callback_
.is_null());
265 write_callback_
= callback
;
266 write_result_
= size
;
267 return net::ERR_IO_PENDING
;
// Completion closure for the channel write issued by Write(). It clears
// |write_pending_|. If Write() already returned ERR_IO_PENDING, it also
// reports |write_result_| (the number of bytes sent) to the caller.
273 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
274 write_pending_
= false;
275 if (!write_callback_
.is_null())
276 base::ResetAndReturn(&write_callback_
).Run(write_result_
);
// Records |error| so that later Read() and Write() calls fail with it. It
// then reports the error to a pending read or, per the comment below, only a
// single pending callback.
// NOTE(review): the statements after the read callback (lines 292-294) are not
// shown; presumably the function returns there. Confirm.
280 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error
) {
281 base_channel_error_
= error;
283 // Here only one of the read and write callbacks is called if both of them are
284 // pending. Ideally both of them should be called in that case, but that would
285 // require the second one to be called asynchronously which would complicate
286 // this code. Channels handle read and write errors the same way (see
287 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
288 // callbacks is enough.
290 if (!read_callback_
.is_null()) {
291 base::ResetAndReturn(&read_callback_
).Run(error
);
295 if (!write_callback_
.is_null())
296 base::ResetAndReturn(&write_callback_
).Run(error
);
// Called by the channel when new data has been queued. If a read is pending,
// it fills the stored buffer from the queue, drops the buffer reference and
// completes the read with the byte count, which must be positive.
299 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
300 if (!read_callback_
.is_null()) {
301 int result
= channel_
->DoRead(read_buffer_
.get(), read_buffer_size_
);
302 read_buffer_
= nullptr;
303 DCHECK_GT(result
, 0);
304 base::ResetAndReturn(&read_callback_
).Run(result
);
// Creates a multiplexer whose base channel, named |base_channel_name|, is
// requested from |factory| when the first channel is queued. Parsed incoming
// messages are dispatched to OnIncomingPacket().
308 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory
* factory
,
309 const std::string
& base_channel_name
)
310 : base_channel_factory_(factory
),
311 base_channel_name_(base_channel_name
),
313 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket
,
314 base::Unretained(this)),
316 weak_factory_(this) {
// Destroys all channels. Every queued CreateChannel() request must have been
// served or cancelled by now. If base channel creation is still in progress
// (|base_channel_factory_| is cleared once it completes), it is cancelled.
319 ChannelMultiplexer::~ChannelMultiplexer() {
320 DCHECK(pending_channels_
.empty());
321 STLDeleteValues(&channels_
);
323 // Cancel creation of the base channel if it hasn't finished.
324 if (base_channel_factory_
)
325 base_channel_factory_
->CancelChannelCreation(base_channel_name_
);
// Creates a multiplexed channel named |name| and passes its socket to
// |callback|:
//  - synchronously, if the base channel already exists;
//  - synchronously with a null socket, if creating the base channel failed;
//  - otherwise once the base channel is ready. The request is queued, and the
//    first queued request starts creating the base channel.
// NOTE(review): the `!base_channel_.get()` test in the second branch is
// redundant because the first branch already handles a non-null base channel.
328 void ChannelMultiplexer::CreateChannel(const std::string
& name
,
329 const ChannelCreatedCallback
& callback
) {
330 if (base_channel_
.get()) {
331 // Already have |base_channel_|. Create new multiplexed channel
333 callback
.Run(GetOrCreateChannel(name
)->CreateSocket());
334 } else if (!base_channel_
.get() && !base_channel_factory_
) {
335 // Fail synchronously if we failed to create |base_channel_|.
336 callback
.Run(nullptr);
338 // Still waiting for the |base_channel_|.
339 pending_channels_
.push_back(PendingChannel(name
, callback
));
341 // If this is the first multiplexed channel then create the base channel.
342 if (pending_channels_
.size() == 1U) {
343 base_channel_factory_
->CreateChannel(
345 base::Bind(&ChannelMultiplexer::OnBaseChannelReady
,
346 base::Unretained(this)));
// Removes the queued CreateChannel() request for |name|, so that
// DoCreatePendingChannels() never runs its callback.
// NOTE(review): erase() invalidates |it|; the loop presumably exits right
// after (line 356 is not shown). Confirm.
351 void ChannelMultiplexer::CancelChannelCreation(const std::string
& name
) {
352 for (std::list
<PendingChannel
>::iterator it
= pending_channels_
.begin();
353 it
!= pending_channels_
.end(); ++it
) {
354 if (it
->name
== name
) {
355 pending_channels_
.erase(it
);
// Callback from |base_channel_factory_| with the base channel socket, which
// is null if creation failed. It clears the factory pointer because creation
// has finished. On success it starts the reader and attaches the writer to
// the socket; both report failures through OnBaseChannelError(). In either
// case it then serves the queued CreateChannel() requests.
361 void ChannelMultiplexer::OnBaseChannelReady(
362 scoped_ptr
<P2PStreamSocket
> socket
) {
363 base_channel_factory_
= nullptr;
364 base_channel_
= socket
.Pass();
366 if (base_channel_
.get()) {
367 // Initialize reader and writer.
368 reader_
.StartReading(base_channel_
.get(),
369 base::Bind(&ChannelMultiplexer::OnBaseChannelError
,
370 base::Unretained(this)));
371 writer_
.Init(base::Bind(&P2PStreamSocket::Write
,
372 base::Unretained(base_channel_
.get())),
373 base::Bind(&ChannelMultiplexer::OnBaseChannelError
,
374 base::Unretained(this)));
377 DoCreatePendingChannels();
// Serves the oldest queued CreateChannel() request. Its callback receives the
// new channel's socket, or a null socket if the base channel could not be
// created. Any remaining requests are handled by a re-posted task.
380 void ChannelMultiplexer::DoCreatePendingChannels() {
381 if (pending_channels_
.empty())
384 // Every time this function is called it connects a single channel and posts a
385 // separate task to connect other channels. This is necessary because the
386 // callback may destroy the multiplexer or somehow else modify
387 // |pending_channels_| list (e.g. call CancelChannelCreation()).
388 base::ThreadTaskRunnerHandle::Get()->PostTask(
389 FROM_HERE
, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels
,
390 weak_factory_
.GetWeakPtr()));
// Copy the request out before running its callback, which may mutate the
// list or destroy |this|.
392 PendingChannel c
= pending_channels_
.front();
393 pending_channels_
.erase(pending_channels_
.begin());
394 scoped_ptr
<P2PStreamSocket
> socket
;
395 if (base_channel_
.get())
396 socket
= GetOrCreateChannel(c
.name
)->CreateSocket();
397 c
.callback
.Run(socket
.Pass());
// Returns the channel named |name|. If none exists, it creates one and
// registers it in |channels_|, which owns it. A new channel is constructed
// with |next_channel_id_|.
// NOTE(review): the early return and any increment of |next_channel_id_| are
// on lines not shown. Confirm.
400 ChannelMultiplexer::MuxChannel
* ChannelMultiplexer::GetOrCreateChannel(
401 const std::string
& name
) {
402 // Check if we already have a channel with the requested name.
403 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
404 if (it
!= channels_
.end())
407 // Create a new channel if we haven't found existing one.
408 MuxChannel
* channel
= new MuxChannel(this, name
, next_channel_id_
);
410 channels_
[channel
->name()] = channel
;
// Handles a read or write error on the base channel by posting one
// NotifyBaseChannelError() task per channel. Each task carries the channel
// name and a weak pointer, so a channel or multiplexer destroyed before the
// task runs is skipped.
// Posting, rather than calling directly, presumably keeps socket callbacks
// from running while |channels_| is being iterated. Confirm.
415 void ChannelMultiplexer::OnBaseChannelError(int error
) {
416 for (std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.begin();
417 it
!= channels_
.end(); ++it
) {
418 base::ThreadTaskRunnerHandle::Get()->PostTask(
420 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError
,
421 weak_factory_
.GetWeakPtr(), it
->second
->name(), error
));
// Delivers |error| to the channel named |name|, if that channel still exists.
425 void ChannelMultiplexer::NotifyBaseChannelError(const std::string
& name
,
427 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
428 if (it
!= channels_
.end())
429 it
->second
->OnBaseChannelError(error
);
// Parser callback for each MultiplexPacket read from the base channel.
// Packets are routed by channel_id. The first packet seen with a new id must
// carry a channel_name, which is used to look up or create the channel and
// bind the id to it. Packets with no channel_id, or with an unknown id and no
// name, are logged as errors.
// NOTE(review): the early-return paths (lines 437-440, 452-454) are not shown;
// presumably they also run |done_task|. Confirm.
432 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
433 const base::Closure
& done_task
) {
434 DCHECK(packet
->has_channel_id());
435 if (!packet
->has_channel_id()) {
436 LOG(ERROR
) << "Received packet without channel_id.";
441 int receive_id
= packet
->channel_id();
442 MuxChannel
* channel
= nullptr;
443 std::map
<int, MuxChannel
*>::iterator it
=
444 channels_by_receive_id_
.find(receive_id
);
445 if (it
!= channels_by_receive_id_
.end()) {
446 channel
= it
->second
;
448 // This is a new |channel_id| we haven't seen before. Look it up by name.
449 if (!packet
->has_channel_name()) {
450 LOG(ERROR
) << "Received packet with unknown channel_id and "
451 "without channel_name.";
// Bind the peer's id to the named channel so later packets can omit the name.
455 channel
= GetOrCreateChannel(packet
->channel_name());
456 channel
->set_receive_id(receive_id
);
457 channels_by_receive_id_
[receive_id
] = channel
;
460 channel
->OnIncomingPacket(packet
.Pass(), done_task
);
// Serializes and frames |packet|, then hands it to the base channel writer
// together with |done_task|.
463 void ChannelMultiplexer::DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
464 const base::Closure
& done_task
) {
465 writer_
.Write(SerializeAndFrameMessage(*packet
), done_task
);
468 } // namespace protocol
469 } // namespace remoting