Roll src/third_party/WebKit d9c6159:8139f33 (svn 201974:201975)
[chromium-blink-merge.git] / remoting / protocol / channel_multiplexer.cc
blob ff5f990184baa42dbd485f18d236d03eafa458d1
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
7 #include <string.h>
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/callback_helpers.h"
12 #include "base/location.h"
13 #include "base/single_thread_task_runner.h"
14 #include "base/stl_util.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "net/base/net_errors.h"
17 #include "remoting/protocol/message_serialization.h"
18 #include "remoting/protocol/p2p_stream_socket.h"
20 namespace remoting {
21 namespace protocol {
23 namespace {
// Receive id of a MuxChannel until the peer's first packet for it (the one
// carrying the channel name) has been seen.
const int kChannelIdUnknown = -1;

// Maximum number of payload bytes MuxSocket::Write() places into a single
// MultiplexPacket.
const int kMaxPacketSize = 1024;
27 class PendingPacket {
28 public:
29 PendingPacket(scoped_ptr<MultiplexPacket> packet,
30 const base::Closure& done_task)
31 : packet(packet.Pass()),
32 done_task(done_task),
33 pos(0U) {
35 ~PendingPacket() {
36 done_task.Run();
39 bool is_empty() { return pos >= packet->data().size(); }
41 int Read(char* buffer, size_t size) {
42 size = std::min(size, packet->data().size() - pos);
43 memcpy(buffer, packet->data().data() + pos, size);
44 pos += size;
45 return size;
48 private:
49 scoped_ptr<MultiplexPacket> packet;
50 base::Closure done_task;
51 size_t pos;
53 DISALLOW_COPY_AND_ASSIGN(PendingPacket);
56 } // namespace
58 const char ChannelMultiplexer::kMuxChannelName[] = "mux";
60 struct ChannelMultiplexer::PendingChannel {
61 PendingChannel(const std::string& name,
62 const ChannelCreatedCallback& callback)
63 : name(name), callback(callback) {
65 std::string name;
66 ChannelCreatedCallback callback;
69 class ChannelMultiplexer::MuxChannel {
70 public:
71 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
72 int send_id);
73 ~MuxChannel();
75 const std::string& name() { return name_; }
76 int receive_id() { return receive_id_; }
77 void set_receive_id(int id) { receive_id_ = id; }
79 // Called by ChannelMultiplexer.
80 scoped_ptr<P2PStreamSocket> CreateSocket();
81 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
82 const base::Closure& done_task);
83 void OnBaseChannelError(int error);
85 // Called by MuxSocket.
86 void OnSocketDestroyed();
87 void DoWrite(scoped_ptr<MultiplexPacket> packet,
88 const base::Closure& done_task);
89 int DoRead(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len);
91 private:
92 ChannelMultiplexer* multiplexer_;
93 std::string name_;
94 int send_id_;
95 bool id_sent_;
96 int receive_id_;
97 MuxSocket* socket_;
98 std::list<PendingPacket*> pending_packets_;
100 DISALLOW_COPY_AND_ASSIGN(MuxChannel);
103 class ChannelMultiplexer::MuxSocket : public P2PStreamSocket,
104 public base::NonThreadSafe,
105 public base::SupportsWeakPtr<MuxSocket> {
106 public:
107 MuxSocket(MuxChannel* channel);
108 ~MuxSocket() override;
110 void OnWriteComplete();
111 void OnBaseChannelError(int error);
112 void OnPacketReceived();
114 // P2PStreamSocket interface.
115 int Read(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
116 const net::CompletionCallback& callback) override;
117 int Write(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
118 const net::CompletionCallback& callback) override;
120 private:
121 MuxChannel* channel_;
123 int base_channel_error_ = net::OK;
125 net::CompletionCallback read_callback_;
126 scoped_refptr<net::IOBuffer> read_buffer_;
127 int read_buffer_size_;
129 bool write_pending_;
130 int write_result_;
131 net::CompletionCallback write_callback_;
133 DISALLOW_COPY_AND_ASSIGN(MuxSocket);
137 ChannelMultiplexer::MuxChannel::MuxChannel(
138 ChannelMultiplexer* multiplexer,
139 const std::string& name,
140 int send_id)
141 : multiplexer_(multiplexer),
142 name_(name),
143 send_id_(send_id),
144 id_sent_(false),
145 receive_id_(kChannelIdUnknown),
146 socket_(nullptr) {
149 ChannelMultiplexer::MuxChannel::~MuxChannel() {
150 // Socket must be destroyed before the channel.
151 DCHECK(!socket_);
152 STLDeleteElements(&pending_packets_);
155 scoped_ptr<P2PStreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
156 DCHECK(!socket_); // Can't create more than one socket per channel.
157 scoped_ptr<MuxSocket> result(new MuxSocket(this));
158 socket_ = result.get();
159 return result.Pass();
162 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
163 scoped_ptr<MultiplexPacket> packet,
164 const base::Closure& done_task) {
165 DCHECK_EQ(packet->channel_id(), receive_id_);
166 if (packet->data().size() > 0) {
167 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
168 if (socket_) {
169 // Notify the socket that we have more data.
170 socket_->OnPacketReceived();
175 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
176 if (socket_)
177 socket_->OnBaseChannelError(error);
180 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
181 DCHECK(socket_);
182 socket_ = nullptr;
185 void ChannelMultiplexer::MuxChannel::DoWrite(
186 scoped_ptr<MultiplexPacket> packet,
187 const base::Closure& done_task) {
188 packet->set_channel_id(send_id_);
189 if (!id_sent_) {
190 packet->set_channel_name(name_);
191 id_sent_ = true;
193 multiplexer_->DoWrite(packet.Pass(), done_task);
196 int ChannelMultiplexer::MuxChannel::DoRead(
197 const scoped_refptr<net::IOBuffer>& buffer,
198 int buffer_len) {
199 int pos = 0;
200 while (buffer_len > 0 && !pending_packets_.empty()) {
201 DCHECK(!pending_packets_.front()->is_empty());
202 int result = pending_packets_.front()->Read(
203 buffer->data() + pos, buffer_len);
204 DCHECK_LE(result, buffer_len);
205 pos += result;
206 buffer_len -= pos;
207 if (pending_packets_.front()->is_empty()) {
208 delete pending_packets_.front();
209 pending_packets_.erase(pending_packets_.begin());
212 return pos;
215 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
216 : channel_(channel),
217 read_buffer_size_(0),
218 write_pending_(false),
219 write_result_(0) {
222 ChannelMultiplexer::MuxSocket::~MuxSocket() {
223 channel_->OnSocketDestroyed();
226 int ChannelMultiplexer::MuxSocket::Read(
227 const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
228 const net::CompletionCallback& callback) {
229 DCHECK(CalledOnValidThread());
230 DCHECK(read_callback_.is_null());
232 if (base_channel_error_ != net::OK)
233 return base_channel_error_;
235 int result = channel_->DoRead(buffer, buffer_len);
236 if (result == 0) {
237 read_buffer_ = buffer;
238 read_buffer_size_ = buffer_len;
239 read_callback_ = callback;
240 return net::ERR_IO_PENDING;
242 return result;
245 int ChannelMultiplexer::MuxSocket::Write(
246 const scoped_refptr<net::IOBuffer>& buffer, int buffer_len,
247 const net::CompletionCallback& callback) {
248 DCHECK(CalledOnValidThread());
249 DCHECK(write_callback_.is_null());
251 if (base_channel_error_ != net::OK)
252 return base_channel_error_;
254 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
255 size_t size = std::min(kMaxPacketSize, buffer_len);
256 packet->mutable_data()->assign(buffer->data(), size);
258 write_pending_ = true;
259 channel_->DoWrite(packet.Pass(), base::Bind(
260 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
262 // OnWriteComplete() might be called above synchronously.
263 if (write_pending_) {
264 DCHECK(write_callback_.is_null());
265 write_callback_ = callback;
266 write_result_ = size;
267 return net::ERR_IO_PENDING;
270 return size;
273 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
274 write_pending_ = false;
275 if (!write_callback_.is_null())
276 base::ResetAndReturn(&write_callback_).Run(write_result_);
280 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
281 base_channel_error_ = error;
283 // Here only one of the read and write callbacks is called if both of them are
284 // pending. Ideally both of them should be called in that case, but that would
285 // require the second one to be called asynchronously which would complicate
286 // this code. Channels handle read and write errors the same way (see
287 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
288 // callbacks is enough.
290 if (!read_callback_.is_null()) {
291 base::ResetAndReturn(&read_callback_).Run(error);
292 return;
295 if (!write_callback_.is_null())
296 base::ResetAndReturn(&write_callback_).Run(error);
299 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
300 if (!read_callback_.is_null()) {
301 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
302 read_buffer_ = nullptr;
303 DCHECK_GT(result, 0);
304 base::ResetAndReturn(&read_callback_).Run(result);
308 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
309 const std::string& base_channel_name)
310 : base_channel_factory_(factory),
311 base_channel_name_(base_channel_name),
312 next_channel_id_(0),
313 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket,
314 base::Unretained(this)),
315 &reader_),
316 weak_factory_(this) {
319 ChannelMultiplexer::~ChannelMultiplexer() {
320 DCHECK(pending_channels_.empty());
321 STLDeleteValues(&channels_);
323 // Cancel creation of the base channel if it hasn't finished.
324 if (base_channel_factory_)
325 base_channel_factory_->CancelChannelCreation(base_channel_name_);
328 void ChannelMultiplexer::CreateChannel(const std::string& name,
329 const ChannelCreatedCallback& callback) {
330 if (base_channel_.get()) {
331 // Already have |base_channel_|. Create new multiplexed channel
332 // synchronously.
333 callback.Run(GetOrCreateChannel(name)->CreateSocket());
334 } else if (!base_channel_.get() && !base_channel_factory_) {
335 // Fail synchronously if we failed to create |base_channel_|.
336 callback.Run(nullptr);
337 } else {
338 // Still waiting for the |base_channel_|.
339 pending_channels_.push_back(PendingChannel(name, callback));
341 // If this is the first multiplexed channel then create the base channel.
342 if (pending_channels_.size() == 1U) {
343 base_channel_factory_->CreateChannel(
344 base_channel_name_,
345 base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
346 base::Unretained(this)));
351 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
352 for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
353 it != pending_channels_.end(); ++it) {
354 if (it->name == name) {
355 pending_channels_.erase(it);
356 return;
361 void ChannelMultiplexer::OnBaseChannelReady(
362 scoped_ptr<P2PStreamSocket> socket) {
363 base_channel_factory_ = nullptr;
364 base_channel_ = socket.Pass();
366 if (base_channel_.get()) {
367 // Initialize reader and writer.
368 reader_.StartReading(base_channel_.get(),
369 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
370 base::Unretained(this)));
371 writer_.Init(base::Bind(&P2PStreamSocket::Write,
372 base::Unretained(base_channel_.get())),
373 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
374 base::Unretained(this)));
377 DoCreatePendingChannels();
380 void ChannelMultiplexer::DoCreatePendingChannels() {
381 if (pending_channels_.empty())
382 return;
384 // Every time this function is called it connects a single channel and posts a
385 // separate task to connect other channels. This is necessary because the
386 // callback may destroy the multiplexer or somehow else modify
387 // |pending_channels_| list (e.g. call CancelChannelCreation()).
388 base::ThreadTaskRunnerHandle::Get()->PostTask(
389 FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
390 weak_factory_.GetWeakPtr()));
392 PendingChannel c = pending_channels_.front();
393 pending_channels_.erase(pending_channels_.begin());
394 scoped_ptr<P2PStreamSocket> socket;
395 if (base_channel_.get())
396 socket = GetOrCreateChannel(c.name)->CreateSocket();
397 c.callback.Run(socket.Pass());
400 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
401 const std::string& name) {
402 // Check if we already have a channel with the requested name.
403 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
404 if (it != channels_.end())
405 return it->second;
407 // Create a new channel if we haven't found existing one.
408 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
409 ++next_channel_id_;
410 channels_[channel->name()] = channel;
411 return channel;
415 void ChannelMultiplexer::OnBaseChannelError(int error) {
416 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
417 it != channels_.end(); ++it) {
418 base::ThreadTaskRunnerHandle::Get()->PostTask(
419 FROM_HERE,
420 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
421 weak_factory_.GetWeakPtr(), it->second->name(), error));
425 void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
426 int error) {
427 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
428 if (it != channels_.end())
429 it->second->OnBaseChannelError(error);
432 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
433 const base::Closure& done_task) {
434 DCHECK(packet->has_channel_id());
435 if (!packet->has_channel_id()) {
436 LOG(ERROR) << "Received packet without channel_id.";
437 done_task.Run();
438 return;
441 int receive_id = packet->channel_id();
442 MuxChannel* channel = nullptr;
443 std::map<int, MuxChannel*>::iterator it =
444 channels_by_receive_id_.find(receive_id);
445 if (it != channels_by_receive_id_.end()) {
446 channel = it->second;
447 } else {
448 // This is a new |channel_id| we haven't seen before. Look it up by name.
449 if (!packet->has_channel_name()) {
450 LOG(ERROR) << "Received packet with unknown channel_id and "
451 "without channel_name.";
452 done_task.Run();
453 return;
455 channel = GetOrCreateChannel(packet->channel_name());
456 channel->set_receive_id(receive_id);
457 channels_by_receive_id_[receive_id] = channel;
460 channel->OnIncomingPacket(packet.Pass(), done_task);
463 void ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
464 const base::Closure& done_task) {
465 writer_.Write(SerializeAndFrameMessage(*packet), done_task);
468 } // namespace protocol
469 } // namespace remoting