Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / remoting / protocol / channel_multiplexer.cc
blobe19a9a1229227a4f1ea8363f1efbbdfc80df53ec
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
7 #include <string.h>
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/callback_helpers.h"
12 #include "base/location.h"
13 #include "base/single_thread_task_runner.h"
14 #include "base/stl_util.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "net/base/net_errors.h"
17 #include "remoting/protocol/message_serialization.h"
18 #include "remoting/protocol/p2p_stream_socket.h"
20 namespace remoting {
21 namespace protocol {
23 namespace {
24 const int kChannelIdUnknown = -1;
25 const int kMaxPacketSize = 1024;
27 class PendingPacket {
28 public:
29 PendingPacket(scoped_ptr<MultiplexPacket> packet,
30 const base::Closure& done_task)
31 : packet(packet.Pass()),
32 done_task(done_task),
33 pos(0U) {
35 ~PendingPacket() {
36 done_task.Run();
39 bool is_empty() { return pos >= packet->data().size(); }
41 int Read(char* buffer, size_t size) {
42 size = std::min(size, packet->data().size() - pos);
43 memcpy(buffer, packet->data().data() + pos, size);
44 pos += size;
45 return size;
48 private:
49 scoped_ptr<MultiplexPacket> packet;
50 base::Closure done_task;
51 size_t pos;
53 DISALLOW_COPY_AND_ASSIGN(PendingPacket);
56 } // namespace
58 const char ChannelMultiplexer::kMuxChannelName[] = "mux";
60 struct ChannelMultiplexer::PendingChannel {
61 PendingChannel(const std::string& name,
62 const ChannelCreatedCallback& callback)
63 : name(name), callback(callback) {
65 std::string name;
66 ChannelCreatedCallback callback;
69 class ChannelMultiplexer::MuxChannel {
70 public:
71 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
72 int send_id);
73 ~MuxChannel();
75 const std::string& name() { return name_; }
76 int receive_id() { return receive_id_; }
77 void set_receive_id(int id) { receive_id_ = id; }
79 // Called by ChannelMultiplexer.
80 scoped_ptr<P2PStreamSocket> CreateSocket();
81 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
82 const base::Closure& done_task);
83 void OnBaseChannelError(int error);
85 // Called by MuxSocket.
86 void OnSocketDestroyed();
87 bool DoWrite(scoped_ptr<MultiplexPacket> packet,
88 const base::Closure& done_task);
89 int DoRead(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len);
91 private:
92 ChannelMultiplexer* multiplexer_;
93 std::string name_;
94 int send_id_;
95 bool id_sent_;
96 int receive_id_;
97 MuxSocket* socket_;
98 std::list<PendingPacket*> pending_packets_;
100 DISALLOW_COPY_AND_ASSIGN(MuxChannel);
103 class ChannelMultiplexer::MuxSocket : public P2PStreamSocket,
104 public base::NonThreadSafe,
105 public base::SupportsWeakPtr<MuxSocket> {
106 public:
107 MuxSocket(MuxChannel* channel);
108 ~MuxSocket() override;
110 void OnWriteComplete();
111 void OnBaseChannelError(int error);
112 void OnPacketReceived();
114 // P2PStreamSocket interface.
115 int Read(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
116 const net::CompletionCallback& callback) override;
117 int Write(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
118 const net::CompletionCallback& callback) override;
120 private:
121 MuxChannel* channel_;
123 int base_channel_error_ = net::OK;
125 net::CompletionCallback read_callback_;
126 scoped_refptr<net::IOBuffer> read_buffer_;
127 int read_buffer_size_;
129 bool write_pending_;
130 int write_result_;
131 net::CompletionCallback write_callback_;
133 DISALLOW_COPY_AND_ASSIGN(MuxSocket);
137 ChannelMultiplexer::MuxChannel::MuxChannel(
138 ChannelMultiplexer* multiplexer,
139 const std::string& name,
140 int send_id)
141 : multiplexer_(multiplexer),
142 name_(name),
143 send_id_(send_id),
144 id_sent_(false),
145 receive_id_(kChannelIdUnknown),
146 socket_(nullptr) {
149 ChannelMultiplexer::MuxChannel::~MuxChannel() {
150 // Socket must be destroyed before the channel.
151 DCHECK(!socket_);
152 STLDeleteElements(&pending_packets_);
155 scoped_ptr<P2PStreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
156 DCHECK(!socket_); // Can't create more than one socket per channel.
157 scoped_ptr<MuxSocket> result(new MuxSocket(this));
158 socket_ = result.get();
159 return result.Pass();
162 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
163 scoped_ptr<MultiplexPacket> packet,
164 const base::Closure& done_task) {
165 DCHECK_EQ(packet->channel_id(), receive_id_);
166 if (packet->data().size() > 0) {
167 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
168 if (socket_) {
169 // Notify the socket that we have more data.
170 socket_->OnPacketReceived();
175 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
176 if (socket_)
177 socket_->OnBaseChannelError(error);
180 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
181 DCHECK(socket_);
182 socket_ = nullptr;
185 bool ChannelMultiplexer::MuxChannel::DoWrite(
186 scoped_ptr<MultiplexPacket> packet,
187 const base::Closure& done_task) {
188 packet->set_channel_id(send_id_);
189 if (!id_sent_) {
190 packet->set_channel_name(name_);
191 id_sent_ = true;
193 return multiplexer_->DoWrite(packet.Pass(), done_task);
196 int ChannelMultiplexer::MuxChannel::DoRead(
197 const scoped_refptr<net::IOBuffer>& buffer,
198 int buffer_len) {
199 int pos = 0;
200 while (buffer_len > 0 && !pending_packets_.empty()) {
201 DCHECK(!pending_packets_.front()->is_empty());
202 int result = pending_packets_.front()->Read(
203 buffer->data() + pos, buffer_len);
204 DCHECK_LE(result, buffer_len);
205 pos += result;
206 buffer_len -= pos;
207 if (pending_packets_.front()->is_empty()) {
208 delete pending_packets_.front();
209 pending_packets_.erase(pending_packets_.begin());
212 return pos;
215 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
216 : channel_(channel),
217 read_buffer_size_(0),
218 write_pending_(false),
219 write_result_(0) {
222 ChannelMultiplexer::MuxSocket::~MuxSocket() {
223 channel_->OnSocketDestroyed();
226 int ChannelMultiplexer::MuxSocket::Read(
227 const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
228 const net::CompletionCallback& callback) {
229 DCHECK(CalledOnValidThread());
230 DCHECK(read_callback_.is_null());
232 if (base_channel_error_ != net::OK)
233 return base_channel_error_;
235 int result = channel_->DoRead(buffer, buffer_len);
236 if (result == 0) {
237 read_buffer_ = buffer;
238 read_buffer_size_ = buffer_len;
239 read_callback_ = callback;
240 return net::ERR_IO_PENDING;
242 return result;
245 int ChannelMultiplexer::MuxSocket::Write(
246 const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
247 const net::CompletionCallback& callback) {
248 DCHECK(CalledOnValidThread());
249 DCHECK(write_callback_.is_null());
251 if (base_channel_error_ != net::OK)
252 return base_channel_error_;
254 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
255 size_t size = std::min(kMaxPacketSize, buffer_len);
256 packet->mutable_data()->assign(buffer->data(), size);
258 write_pending_ = true;
259 bool result = channel_->DoWrite(packet.Pass(), base::Bind(
260 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
262 if (!result) {
263 // Cannot complete the write, e.g. if the connection has been terminated.
264 return net::ERR_FAILED;
267 // OnWriteComplete() might be called above synchronously.
268 if (write_pending_) {
269 DCHECK(write_callback_.is_null());
270 write_callback_ = callback;
271 write_result_ = size;
272 return net::ERR_IO_PENDING;
275 return size;
278 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
279 write_pending_ = false;
280 if (!write_callback_.is_null())
281 base::ResetAndReturn(&write_callback_).Run(write_result_);
285 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
286 base_channel_error_ = error;
288 // Here only one of the read and write callbacks is called if both of them are
289 // pending. Ideally both of them should be called in that case, but that would
290 // require the second one to be called asynchronously which would complicate
291 // this code. Channels handle read and write errors the same way (see
292 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
293 // callbacks is enough.
295 if (!read_callback_.is_null()) {
296 base::ResetAndReturn(&read_callback_).Run(error);
297 return;
300 if (!write_callback_.is_null())
301 base::ResetAndReturn(&write_callback_).Run(error);
304 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
305 if (!read_callback_.is_null()) {
306 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
307 read_buffer_ = nullptr;
308 DCHECK_GT(result, 0);
309 base::ResetAndReturn(&read_callback_).Run(result);
313 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
314 const std::string& base_channel_name)
315 : base_channel_factory_(factory),
316 base_channel_name_(base_channel_name),
317 next_channel_id_(0),
318 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket,
319 base::Unretained(this)),
320 &reader_),
321 weak_factory_(this) {
324 ChannelMultiplexer::~ChannelMultiplexer() {
325 DCHECK(pending_channels_.empty());
326 STLDeleteValues(&channels_);
328 // Cancel creation of the base channel if it hasn't finished.
329 if (base_channel_factory_)
330 base_channel_factory_->CancelChannelCreation(base_channel_name_);
333 void ChannelMultiplexer::CreateChannel(const std::string& name,
334 const ChannelCreatedCallback& callback) {
335 if (base_channel_.get()) {
336 // Already have |base_channel_|. Create new multiplexed channel
337 // synchronously.
338 callback.Run(GetOrCreateChannel(name)->CreateSocket());
339 } else if (!base_channel_.get() && !base_channel_factory_) {
340 // Fail synchronously if we failed to create |base_channel_|.
341 callback.Run(nullptr);
342 } else {
343 // Still waiting for the |base_channel_|.
344 pending_channels_.push_back(PendingChannel(name, callback));
346 // If this is the first multiplexed channel then create the base channel.
347 if (pending_channels_.size() == 1U) {
348 base_channel_factory_->CreateChannel(
349 base_channel_name_,
350 base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
351 base::Unretained(this)));
356 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
357 for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
358 it != pending_channels_.end(); ++it) {
359 if (it->name == name) {
360 pending_channels_.erase(it);
361 return;
366 void ChannelMultiplexer::OnBaseChannelReady(
367 scoped_ptr<P2PStreamSocket> socket) {
368 base_channel_factory_ = nullptr;
369 base_channel_ = socket.Pass();
371 if (base_channel_.get()) {
372 // Initialize reader and writer.
373 reader_.StartReading(base_channel_.get(),
374 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
375 base::Unretained(this)));
376 writer_.Init(base::Bind(&P2PStreamSocket::Write,
377 base::Unretained(base_channel_.get())),
378 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
379 base::Unretained(this)));
382 DoCreatePendingChannels();
385 void ChannelMultiplexer::DoCreatePendingChannels() {
386 if (pending_channels_.empty())
387 return;
389 // Every time this function is called it connects a single channel and posts a
390 // separate task to connect other channels. This is necessary because the
391 // callback may destroy the multiplexer or somehow else modify
392 // |pending_channels_| list (e.g. call CancelChannelCreation()).
393 base::ThreadTaskRunnerHandle::Get()->PostTask(
394 FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
395 weak_factory_.GetWeakPtr()));
397 PendingChannel c = pending_channels_.front();
398 pending_channels_.erase(pending_channels_.begin());
399 scoped_ptr<P2PStreamSocket> socket;
400 if (base_channel_.get())
401 socket = GetOrCreateChannel(c.name)->CreateSocket();
402 c.callback.Run(socket.Pass());
405 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
406 const std::string& name) {
407 // Check if we already have a channel with the requested name.
408 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
409 if (it != channels_.end())
410 return it->second;
412 // Create a new channel if we haven't found existing one.
413 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
414 ++next_channel_id_;
415 channels_[channel->name()] = channel;
416 return channel;
420 void ChannelMultiplexer::OnBaseChannelError(int error) {
421 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
422 it != channels_.end(); ++it) {
423 base::ThreadTaskRunnerHandle::Get()->PostTask(
424 FROM_HERE,
425 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
426 weak_factory_.GetWeakPtr(), it->second->name(), error));
430 void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
431 int error) {
432 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
433 if (it != channels_.end())
434 it->second->OnBaseChannelError(error);
437 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
438 const base::Closure& done_task) {
439 DCHECK(packet->has_channel_id());
440 if (!packet->has_channel_id()) {
441 LOG(ERROR) << "Received packet without channel_id.";
442 done_task.Run();
443 return;
446 int receive_id = packet->channel_id();
447 MuxChannel* channel = nullptr;
448 std::map<int, MuxChannel*>::iterator it =
449 channels_by_receive_id_.find(receive_id);
450 if (it != channels_by_receive_id_.end()) {
451 channel = it->second;
452 } else {
453 // This is a new |channel_id| we haven't seen before. Look it up by name.
454 if (!packet->has_channel_name()) {
455 LOG(ERROR) << "Received packet with unknown channel_id and "
456 "without channel_name.";
457 done_task.Run();
458 return;
460 channel = GetOrCreateChannel(packet->channel_name());
461 channel->set_receive_id(receive_id);
462 channels_by_receive_id_[receive_id] = channel;
465 channel->OnIncomingPacket(packet.Pass(), done_task);
468 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
469 const base::Closure& done_task) {
470 return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
473 } // namespace protocol
474 } // namespace remoting