1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
10 #include "base/callback.h"
11 #include "base/callback_helpers.h"
12 #include "base/location.h"
13 #include "base/single_thread_task_runner.h"
14 #include "base/stl_util.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "net/base/net_errors.h"
17 #include "remoting/protocol/message_serialization.h"
18 #include "remoting/protocol/p2p_stream_socket.h"
// Initial |receive_id_| of a MuxChannel; replaced via set_receive_id() when
// ChannelMultiplexer::OnIncomingPacket() first sees a packet for the channel.
24 const int kChannelIdUnknown
= -1;
// Maximum number of payload bytes MuxSocket::Write() copies into a single
// MultiplexPacket; a larger write completes partially with this many bytes.
25 const int kMaxPacketSize
= 1024;
// A received MultiplexPacket queued by MuxChannel::OnIncomingPacket() and
// consumed incrementally by MuxChannel::DoRead(). |done_task| is the closure
// that arrived together with the packet.
// NOTE(review): presumably the destructor runs |done_task| once the packet is
// consumed or discarded; this listing omits it, so confirm.
29 PendingPacket(scoped_ptr
<MultiplexPacket
> packet
,
30 const base::Closure
& done_task
)
31 : packet(packet
.Pass()),
// True once |pos| (the read offset into packet->data()) has reached the end
// of the payload.
39 bool is_empty() { return pos
>= packet
->data().size(); }
// Copies up to |size| not-yet-consumed payload bytes, starting at offset
// |pos|, into |buffer|.
// NOTE(review): advancing |pos| and the return statement are omitted from
// this listing; MuxChannel::DoRead() treats the result as a byte count.
41 int Read(char* buffer
, size_t size
) {
42 size
= std::min(size
, packet
->data().size() - pos
);
43 memcpy(buffer
, packet
->data().data() + pos
, size
);
// The queued packet and the closure delivered with it.
49 scoped_ptr
<MultiplexPacket
> packet
;
50 base::Closure done_task
;
53 DISALLOW_COPY_AND_ASSIGN(PendingPacket
);
// Out-of-line definition of the class-scope constant
// ChannelMultiplexer::kMuxChannelName.
58 const char ChannelMultiplexer::kMuxChannelName
[] = "mux";
// A CreateChannel() request queued while the base channel is still being
// created. DoCreatePendingChannels() later runs |callback| with the new
// channel's socket, or with a null socket if the base channel failed.
60 struct ChannelMultiplexer::PendingChannel
{
61 PendingChannel(const std::string
& name
,
62 const ChannelCreatedCallback
& callback
)
63 : name(name
), callback(callback
) {
66 ChannelCreatedCallback callback
;
// One logical channel carried over the shared base channel. It owns the queue
// of received packets (|pending_packets_|) and hands out at most one MuxSocket
// (see CreateSocket()).
69 class ChannelMultiplexer::MuxChannel
{
71 MuxChannel(ChannelMultiplexer
* multiplexer
, const std::string
& name
,
75 const std::string
& name() { return name_
; }
// channel_id carried by incoming packets for this channel. It is
// kChannelIdUnknown until ChannelMultiplexer::OnIncomingPacket() assigns it.
76 int receive_id() { return receive_id_
; }
77 void set_receive_id(int id
) { receive_id_
= id
; }
79 // Called by ChannelMultiplexer.
80 scoped_ptr
<P2PStreamSocket
> CreateSocket();
81 void OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
82 const base::Closure
& done_task
);
83 void OnBaseChannelError(int error
);
85 // Called by MuxSocket.
86 void OnSocketDestroyed();
87 bool DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
88 const base::Closure
& done_task
);
89 int DoRead(const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
);
// Back-pointer to the multiplexer, which deletes its channels in
// ~ChannelMultiplexer().
92 ChannelMultiplexer
* multiplexer_
;
// Received packets not yet fully read. They are owned here and deleted in
// DoRead() once consumed, or in ~MuxChannel().
98 std::list
<PendingPacket
*> pending_packets_
;
100 DISALLOW_COPY_AND_ASSIGN(MuxChannel
);
// P2PStreamSocket handed to users of a single MuxChannel. Reads are served from
// the channel's queued packets. Writes are wrapped in MultiplexPackets and sent
// through the channel. Must be used on one thread (base::NonThreadSafe).
103 class ChannelMultiplexer::MuxSocket
: public P2PStreamSocket
,
104 public base::NonThreadSafe
,
105 public base::SupportsWeakPtr
<MuxSocket
> {
107 MuxSocket(MuxChannel
* channel
);
108 ~MuxSocket() override
;
// OnWriteComplete() is the done-task bound in Write(). OnBaseChannelError()
// and OnPacketReceived() are called by the owning MuxChannel.
110 void OnWriteComplete();
111 void OnBaseChannelError(int error
);
112 void OnPacketReceived();
114 // P2PStreamSocket interface.
115 int Read(const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
,
116 const net::CompletionCallback
& callback
) override
;
117 int Write(const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
,
118 const net::CompletionCallback
& callback
) override
;
121 MuxChannel
* channel_
;
// Error reported by the base channel (net::OK until then). Once it is set,
// Read() and Write() return it immediately.
123 int base_channel_error_
= net::OK
;
// State of a Read() that returned ERR_IO_PENDING; OnPacketReceived()
// completes it.
125 net::CompletionCallback read_callback_
;
126 scoped_refptr
<net::IOBuffer
> read_buffer_
;
127 int read_buffer_size_
;
// Callback of a Write() that returned ERR_IO_PENDING; it is run from
// OnWriteComplete() or OnBaseChannelError().
131 net::CompletionCallback write_callback_
;
133 DISALLOW_COPY_AND_ASSIGN(MuxSocket
);
// |receive_id_| starts as kChannelIdUnknown.
// ChannelMultiplexer::OnIncomingPacket() assigns the real id through
// set_receive_id() when the first packet for this channel arrives.
137 ChannelMultiplexer::MuxChannel::MuxChannel(
138 ChannelMultiplexer
* multiplexer
,
139 const std::string
& name
,
141 : multiplexer_(multiplexer
),
145 receive_id_(kChannelIdUnknown
),
// Destroys the channel together with any queued packets it still owns.
149 ChannelMultiplexer::MuxChannel::~MuxChannel() {
150 // Socket must be destroyed before the channel.
// Frees received packets that were never fully read.
152 STLDeleteElements(&pending_packets_
);
// Creates this channel's only MuxSocket. Ownership passes to the caller. The
// channel keeps a non-owning |socket_| pointer so it can forward incoming data
// and base-channel errors to the socket.
155 scoped_ptr
<P2PStreamSocket
> ChannelMultiplexer::MuxChannel::CreateSocket() {
156 DCHECK(!socket_
); // Can't create more than one socket per channel.
157 scoped_ptr
<MuxSocket
> result(new MuxSocket(this));
// Raw pointer only: ~MuxSocket() reports back through OnSocketDestroyed().
158 socket_
= result
.get();
159 return result
.Pass();
// Queues a non-empty packet routed to this channel (taking ownership along with
// its |done_task|) and tells the socket that new data is available.
// NOTE(review): a channel can be created by an incoming packet before any
// socket exists (ChannelMultiplexer::OnIncomingPacket -> GetOrCreateChannel).
// This listing dereferences |socket_| without a visible null check; confirm a
// guard exists.
162 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
163 scoped_ptr
<MultiplexPacket
> packet
,
164 const base::Closure
& done_task
) {
165 DCHECK_EQ(packet
->channel_id(), receive_id_
);
166 if (packet
->data().size() > 0) {
167 pending_packets_
.push_back(new PendingPacket(packet
.Pass(), done_task
));
169 // Notify the socket that we have more data.
170 socket_
->OnPacketReceived();
// Forwards a base-channel error to this channel's socket.
// NOTE(review): |socket_| may be null if no socket was ever created or it was
// already destroyed; confirm a null check guards this call (none is shown).
175 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error
) {
177 socket_
->OnBaseChannelError(error
);
// Called from ~MuxSocket(). The body is not shown in this listing; it
// presumably clears |socket_| so the channel stops using a dangling pointer —
// confirm.
180 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
// Stamps |packet| with this channel's send id and name, then hands it to the
// multiplexer's writer. Returns ChannelMultiplexer::DoWrite()'s result.
// NOTE(review): the name may be attached only conditionally (e.g. on the first
// packet); the guarding line is not shown in this listing.
185 bool ChannelMultiplexer::MuxChannel::DoWrite(
186 scoped_ptr
<MultiplexPacket
> packet
,
187 const base::Closure
& done_task
) {
188 packet
->set_channel_id(send_id_
);
190 packet
->set_channel_name(name_
);
193 return multiplexer_
->DoWrite(packet
.Pass(), done_task
);
// Copies queued payload bytes into |buffer| (writing at offset |pos|) until
// |buffer_len| is exhausted or no packets remain. Fully consumed packets are
// deleted and removed from the queue.
// NOTE(review): the |pos| / |buffer_len| updates and the return statement are
// not shown in this listing. MuxSocket::Read() and OnPacketReceived() treat the
// result as the number of bytes copied.
196 int ChannelMultiplexer::MuxChannel::DoRead(
197 const scoped_refptr
<net::IOBuffer
>& buffer
,
200 while (buffer_len
> 0 && !pending_packets_
.empty()) {
// Empty packets are never queued (see OnIncomingPacket()), and consumed
// ones are removed below.
201 DCHECK(!pending_packets_
.front()->is_empty());
202 int result
= pending_packets_
.front()->Read(
203 buffer
->data() + pos
, buffer_len
);
204 DCHECK_LE(result
, buffer_len
);
207 if (pending_packets_
.front()->is_empty()) {
208 delete pending_packets_
.front();
209 pending_packets_
.erase(pending_packets_
.begin());
// A new socket has no pending read buffer and no write in flight.
215 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel
* channel
)
217 read_buffer_size_(0),
218 write_pending_(false),
// Tells the channel that its socket is gone, so it can drop its non-owning
// |socket_| pointer.
222 ChannelMultiplexer::MuxSocket::~MuxSocket() {
223 channel_
->OnSocketDestroyed();
// P2PStreamSocket::Read(). Returns the stored base-channel error if one was
// reported. Otherwise it reads from the channel's queue. If the read is left
// pending, |buffer| and |callback| are saved and OnPacketReceived() completes
// them when data arrives.
// NOTE(review): as listed, |result| is never used and every read goes pending.
// The branch that returns |result| when data was already available (and only
// otherwise saves the pending state) is not shown; confirm it exists.
226 int ChannelMultiplexer::MuxSocket::Read(
227 const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
,
228 const net::CompletionCallback
& callback
) {
229 DCHECK(CalledOnValidThread());
// Only one outstanding Read() is supported.
230 DCHECK(read_callback_
.is_null());
232 if (base_channel_error_
!= net::OK
)
233 return base_channel_error_
;
235 int result
= channel_
->DoRead(buffer
, buffer_len
);
// Keep a reference to the caller's buffer until the read completes.
237 read_buffer_
= buffer
;
238 read_buffer_size_
= buffer_len
;
239 read_callback_
= callback
;
240 return net::ERR_IO_PENDING
;
// P2PStreamSocket::Write(). Sends at most kMaxPacketSize bytes of |buffer| as
// one MultiplexPacket, so large writes complete partially. If the packet's
// done-task (OnWriteComplete) has not run by the time DoWrite() returns, this
// returns ERR_IO_PENDING and |callback| later receives the number of bytes
// sent.
// NOTE(review): |size| is a size_t taken from std::min of two ints, so a
// negative |buffer_len| would become a huge length. Presumably callers never
// pass a negative length; confirm.
245 int ChannelMultiplexer::MuxSocket::Write(
246 const scoped_refptr
<net::IOBuffer
>& buffer
, int buffer_len
,
247 const net::CompletionCallback
& callback
) {
248 DCHECK(CalledOnValidThread());
// Only one outstanding Write() is supported.
249 DCHECK(write_callback_
.is_null());
251 if (base_channel_error_
!= net::OK
)
252 return base_channel_error_
;
254 scoped_ptr
<MultiplexPacket
> packet(new MultiplexPacket());
255 size_t size
= std::min(kMaxPacketSize
, buffer_len
);
256 packet
->mutable_data()->assign(buffer
->data(), size
);
// Set before DoWrite() so a synchronous OnWriteComplete() can clear it.
258 write_pending_
= true;
// The weak pointer keeps the done-task from touching a deleted socket.
259 bool result
= channel_
->DoWrite(packet
.Pass(), base::Bind(
260 &ChannelMultiplexer::MuxSocket::OnWriteComplete
, AsWeakPtr()));
// NOTE(review): the `if (!result)` guard around the failure return below is
// not shown in this listing; confirm it exists.
263 // Cannot complete the write, e.g. if the connection has been terminated.
264 return net::ERR_FAILED
;
267 // OnWriteComplete() might be called above synchronously.
268 if (write_pending_
) {
269 DCHECK(write_callback_
.is_null());
270 write_callback_
= callback
;
271 write_result_
= size
;
272 return net::ERR_IO_PENDING
;
// Done-task for the packet sent by Write(). Clearing |write_pending_| lets
// Write() detect a synchronous completion. If Write() already returned
// ERR_IO_PENDING, this runs the saved callback with the byte count.
278 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
279 write_pending_
= false;
280 if (!write_callback_
.is_null())
281 base::ResetAndReturn(&write_callback_
).Run(write_result_
);
// Records |error| so later Read()/Write() calls fail with it, then fails a
// pending read or write.
// NOTE(review): running |read_callback_| may destroy this socket. Confirm the
// lines after it (not shown in this listing) return before |write_callback_|
// is touched, which matches the "only one callback" comment below.
285 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error
) {
286 base_channel_error_
= error
;
288 // Here only one of the read and write callbacks is called if both of them are
289 // pending. Ideally both of them should be called in that case, but that would
290 // require the second one to be called asynchronously which would complicate
291 // this code. Channels handle read and write errors the same way (see
292 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
293 // callbacks is enough.
295 if (!read_callback_
.is_null()) {
296 base::ResetAndReturn(&read_callback_
).Run(error
);
300 if (!write_callback_
.is_null())
301 base::ResetAndReturn(&write_callback_
).Run(error
);
// Called by the channel after it queues new data. If a Read() is pending, this
// fills the saved buffer, releases the reference to it, and reports the byte
// count to the saved callback.
304 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
305 if (!read_callback_
.is_null()) {
306 int result
= channel_
->DoRead(read_buffer_
.get(), read_buffer_size_
);
307 read_buffer_
= nullptr;
// Data was just queued, so at least one byte must have been read.
308 DCHECK_GT(result
, 0);
309 base::ResetAndReturn(&read_callback_
).Run(result
);
// |factory| is used to create the base channel |base_channel_name| when the
// first multiplexed channel is requested (see CreateChannel()). |parser_|
// delivers each parsed MultiplexPacket to OnIncomingPacket().
// base::Unretained is presumably safe because |parser_| is a member and cannot
// outlive |this|.
313 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory
* factory
,
314 const std::string
& base_channel_name
)
315 : base_channel_factory_(factory
),
316 base_channel_name_(base_channel_name
),
318 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket
,
319 base::Unretained(this)),
321 weak_factory_(this) {
// Every CreateChannel() request must already be resolved or cancelled. The
// destructor deletes all MuxChannels, whose sockets must already be destroyed
// (see ~MuxChannel). |base_channel_factory_| is non-null only while base
// channel creation is still in progress, because OnBaseChannelReady() clears
// it.
324 ChannelMultiplexer::~ChannelMultiplexer() {
325 DCHECK(pending_channels_
.empty());
326 STLDeleteValues(&channels_
);
328 // Cancel creation of the base channel if it hasn't finished.
329 if (base_channel_factory_
)
330 base_channel_factory_
->CancelChannelCreation(base_channel_name_
);
// Creates (or reuses) the multiplexed channel |name| and passes its socket to
// |callback|:
//  - synchronously, if the base channel already exists;
//  - synchronously with nullptr, if base channel creation already failed;
//  - otherwise later, from DoCreatePendingChannels(), after the request has
//    been queued in |pending_channels_|.
333 void ChannelMultiplexer::CreateChannel(const std::string
& name
,
334 const ChannelCreatedCallback
& callback
) {
335 if (base_channel_
.get()) {
336 // Already have |base_channel_|. Create new multiplexed channel and socket.
338 callback
.Run(GetOrCreateChannel(name
)->CreateSocket());
339 } else if (!base_channel_
.get() && !base_channel_factory_
) {
340 // Fail synchronously if we failed to create |base_channel_|.
341 callback
.Run(nullptr);
343 // Still waiting for the |base_channel_|.
344 pending_channels_
.push_back(PendingChannel(name
, callback
));
346 // If this is the first multiplexed channel then create the base channel.
347 if (pending_channels_
.size() == 1U) {
// base::Unretained: ~ChannelMultiplexer() cancels an unfinished creation.
348 base_channel_factory_
->CreateChannel(
350 base::Bind(&ChannelMultiplexer::OnBaseChannelReady
,
351 base::Unretained(this)));
// Drops a queued CreateChannel() request for |name| so its callback never runs.
// NOTE(review): erase() invalidates |it|, so the loop must stop (return/break)
// right after erasing; otherwise ++it is undefined behavior. That statement is
// not shown in this listing; confirm it exists.
356 void ChannelMultiplexer::CancelChannelCreation(const std::string
& name
) {
357 for (std::list
<PendingChannel
>::iterator it
= pending_channels_
.begin();
358 it
!= pending_channels_
.end(); ++it
) {
359 if (it
->name
== name
) {
360 pending_channels_
.erase(it
);
// Completion callback for base channel creation; |socket| is null on failure.
// Clearing |base_channel_factory_| marks creation as finished, so later
// CreateChannel() calls succeed or fail synchronously. In both cases, queued
// requests are then resolved by DoCreatePendingChannels().
366 void ChannelMultiplexer::OnBaseChannelReady(
367 scoped_ptr
<P2PStreamSocket
> socket
) {
368 base_channel_factory_
= nullptr;
369 base_channel_
= socket
.Pass();
371 if (base_channel_
.get()) {
372 // Initialize reader and writer.
373 reader_
.StartReading(base_channel_
.get(),
374 base::Bind(&ChannelMultiplexer::OnBaseChannelError
,
375 base::Unretained(this)));
376 writer_
.Init(base::Bind(&P2PStreamSocket::Write
,
377 base::Unretained(base_channel_
.get())),
378 base::Bind(&ChannelMultiplexer::OnBaseChannelError
,
379 base::Unretained(this)));
382 DoCreatePendingChannels();
// Resolves the first queued CreateChannel() request and reposts itself
// (through a weak pointer) to handle the rest.
// NOTE(review): the early `return` after the empty() check is not shown in this
// listing; without it, front() would be called on an empty list.
385 void ChannelMultiplexer::DoCreatePendingChannels() {
386 if (pending_channels_
.empty())
389 // Every time this function is called it connects a single channel and posts a
390 // separate task to connect other channels. This is necessary because the
391 // callback may destroy the multiplexer or somehow else modify
392 // |pending_channels_| list (e.g. call CancelChannelCreation()).
393 base::ThreadTaskRunnerHandle::Get()->PostTask(
394 FROM_HERE
, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels
,
395 weak_factory_
.GetWeakPtr()));
// Copy the request out of the list before running its callback.
397 PendingChannel c
= pending_channels_
.front();
398 pending_channels_
.erase(pending_channels_
.begin());
// |socket| stays null if the base channel failed, which reports failure to
// the requester.
399 scoped_ptr
<P2PStreamSocket
> socket
;
400 if (base_channel_
.get())
401 socket
= GetOrCreateChannel(c
.name
)->CreateSocket();
402 c
.callback
.Run(socket
.Pass());
// Returns the channel named |name|. If none exists, creates one with
// |next_channel_id_| as its id and registers it in |channels_|, which owns the
// channels; they are deleted in ~ChannelMultiplexer().
// NOTE(review): the return for the "found" case and the increment of
// |next_channel_id_| are not shown in this listing; confirm both exist.
405 ChannelMultiplexer::MuxChannel
* ChannelMultiplexer::GetOrCreateChannel(
406 const std::string
& name
) {
407 // Check if we already have a channel with the requested name.
408 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
409 if (it
!= channels_
.end())
412 // Create a new channel if we haven't found an existing one.
413 MuxChannel
* channel
= new MuxChannel(this, name
, next_channel_id_
);
415 channels_
[channel
->name()] = channel
;
// Error handler for the base-channel reader and writer. It posts one task per
// channel instead of notifying inline. Each task carries the channel's name,
// not its pointer, so a channel no longer in |channels_| is skipped
// (NotifyBaseChannelError() looks it up again). The weak pointer drops tasks
// if the multiplexer is destroyed first.
420 void ChannelMultiplexer::OnBaseChannelError(int error
) {
421 for (std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.begin();
422 it
!= channels_
.end(); ++it
) {
423 base::ThreadTaskRunnerHandle::Get()->PostTask(
425 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError
,
426 weak_factory_
.GetWeakPtr(), it
->second
->name(), error
));
// Task posted by OnBaseChannelError(). Forwards |error| to the channel named
// |name| if it is still registered; otherwise it does nothing.
430 void ChannelMultiplexer::NotifyBaseChannelError(const std::string
& name
,
432 std::map
<std::string
, MuxChannel
*>::iterator it
= channels_
.find(name
);
433 if (it
!= channels_
.end())
434 it
->second
->OnBaseChannelError(error
);
// Parser callback. Routes |packet| to its MuxChannel by channel_id. The first
// packet carrying a new channel_id must also carry a channel_name; that name
// is used to find or create the channel, and the id -> channel mapping is then
// remembered in |channels_by_receive_id_|. Malformed packets are logged.
// NOTE(review): what happens after each LOG(ERROR) (presumably running
// |done_task| and returning) is not shown in this listing; confirm. Packets
// come from the remote peer, and an unknown name creates a new channel on
// demand.
437 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr
<MultiplexPacket
> packet
,
438 const base::Closure
& done_task
) {
439 DCHECK(packet
->has_channel_id());
440 if (!packet
->has_channel_id()) {
441 LOG(ERROR
) << "Received packet without channel_id.";
446 int receive_id
= packet
->channel_id();
447 MuxChannel
* channel
= nullptr;
448 std::map
<int, MuxChannel
*>::iterator it
=
449 channels_by_receive_id_
.find(receive_id
);
450 if (it
!= channels_by_receive_id_
.end()) {
451 channel
= it
->second
;
453 // This is a new |channel_id| we haven't seen before. Look it up by name.
454 if (!packet
->has_channel_name()) {
455 LOG(ERROR
) << "Received packet with unknown channel_id and "
456 "without channel_name.";
460 channel
= GetOrCreateChannel(packet
->channel_name());
461 channel
->set_receive_id(receive_id
);
462 channels_by_receive_id_
[receive_id
] = channel
;
465 channel
->OnIncomingPacket(packet
.Pass(), done_task
);
// Serializes and frames |packet| and queues it on the base-channel writer.
// |done_task| is passed to the writer with the message. Returns the writer's
// Write() result (MuxSocket::Write() treats false as ERR_FAILED).
468 bool ChannelMultiplexer::DoWrite(scoped_ptr
<MultiplexPacket
> packet
,
469 const base::Closure
& done_task
) {
470 return writer_
.Write(SerializeAndFrameMessage(*packet
), done_task
);
473 } // namespace protocol
474 } // namespace remoting