Add MetricsService::WasLastShutdownClean()
[chromium-blink-merge.git] / remoting / protocol / channel_multiplexer.cc
blob8cdbf0d60af091469c6acb327211529f2e8f2444
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
7 #include <string.h>
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/callback_helpers.h"
12 #include "base/location.h"
13 #include "base/single_thread_task_runner.h"
14 #include "base/stl_util.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "net/base/net_errors.h"
17 #include "net/socket/stream_socket.h"
18 #include "remoting/protocol/message_serialization.h"
20 namespace remoting {
21 namespace protocol {
23 namespace {
// Placeholder stored in MuxChannel::receive_id_ until the first packet for the
// channel arrives with an id assigned by the peer.
const int kChannelIdUnknown = -1;

// Upper bound, in bytes, on the payload a single MuxSocket::Write() call puts
// into one MultiplexPacket.
const int kMaxPacketSize = 1024;
27 class PendingPacket {
28 public:
29 PendingPacket(scoped_ptr<MultiplexPacket> packet,
30 const base::Closure& done_task)
31 : packet(packet.Pass()),
32 done_task(done_task),
33 pos(0U) {
35 ~PendingPacket() {
36 done_task.Run();
39 bool is_empty() { return pos >= packet->data().size(); }
41 int Read(char* buffer, size_t size) {
42 size = std::min(size, packet->data().size() - pos);
43 memcpy(buffer, packet->data().data() + pos, size);
44 pos += size;
45 return size;
48 private:
49 scoped_ptr<MultiplexPacket> packet;
50 base::Closure done_task;
51 size_t pos;
53 DISALLOW_COPY_AND_ASSIGN(PendingPacket);
56 } // namespace
58 const char ChannelMultiplexer::kMuxChannelName[] = "mux";
60 struct ChannelMultiplexer::PendingChannel {
61 PendingChannel(const std::string& name,
62 const ChannelCreatedCallback& callback)
63 : name(name), callback(callback) {
65 std::string name;
66 ChannelCreatedCallback callback;
69 class ChannelMultiplexer::MuxChannel {
70 public:
71 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
72 int send_id);
73 ~MuxChannel();
75 const std::string& name() { return name_; }
76 int receive_id() { return receive_id_; }
77 void set_receive_id(int id) { receive_id_ = id; }
79 // Called by ChannelMultiplexer.
80 scoped_ptr<net::StreamSocket> CreateSocket();
81 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
82 const base::Closure& done_task);
83 void OnBaseChannelError(int error);
85 // Called by MuxSocket.
86 void OnSocketDestroyed();
87 bool DoWrite(scoped_ptr<MultiplexPacket> packet,
88 const base::Closure& done_task);
89 int DoRead(net::IOBuffer* buffer, int buffer_len);
91 private:
92 ChannelMultiplexer* multiplexer_;
93 std::string name_;
94 int send_id_;
95 bool id_sent_;
96 int receive_id_;
97 MuxSocket* socket_;
98 std::list<PendingPacket*> pending_packets_;
100 DISALLOW_COPY_AND_ASSIGN(MuxChannel);
103 class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
104 public base::NonThreadSafe,
105 public base::SupportsWeakPtr<MuxSocket> {
106 public:
107 MuxSocket(MuxChannel* channel);
108 ~MuxSocket() override;
110 void OnWriteComplete();
111 void OnBaseChannelError(int error);
112 void OnPacketReceived();
114 // net::StreamSocket interface.
115 int Read(net::IOBuffer* buffer,
116 int buffer_len,
117 const net::CompletionCallback& callback) override;
118 int Write(net::IOBuffer* buffer,
119 int buffer_len,
120 const net::CompletionCallback& callback) override;
122 int SetReceiveBufferSize(int32 size) override {
123 NOTIMPLEMENTED();
124 return net::ERR_NOT_IMPLEMENTED;
126 int SetSendBufferSize(int32 size) override {
127 NOTIMPLEMENTED();
128 return net::ERR_NOT_IMPLEMENTED;
131 int Connect(const net::CompletionCallback& callback) override {
132 NOTIMPLEMENTED();
133 return net::ERR_NOT_IMPLEMENTED;
135 void Disconnect() override { NOTIMPLEMENTED(); }
136 bool IsConnected() const override {
137 NOTIMPLEMENTED();
138 return true;
140 bool IsConnectedAndIdle() const override {
141 NOTIMPLEMENTED();
142 return false;
144 int GetPeerAddress(net::IPEndPoint* address) const override {
145 NOTIMPLEMENTED();
146 return net::ERR_NOT_IMPLEMENTED;
148 int GetLocalAddress(net::IPEndPoint* address) const override {
149 NOTIMPLEMENTED();
150 return net::ERR_NOT_IMPLEMENTED;
152 const net::BoundNetLog& NetLog() const override {
153 NOTIMPLEMENTED();
154 return net_log_;
156 void SetSubresourceSpeculation() override { NOTIMPLEMENTED(); }
157 void SetOmniboxSpeculation() override { NOTIMPLEMENTED(); }
158 bool WasEverUsed() const override { return true; }
159 bool UsingTCPFastOpen() const override { return false; }
160 bool WasNpnNegotiated() const override { return false; }
161 net::NextProto GetNegotiatedProtocol() const override {
162 return net::kProtoUnknown;
164 bool GetSSLInfo(net::SSLInfo* ssl_info) override {
165 NOTIMPLEMENTED();
166 return false;
168 void GetConnectionAttempts(net::ConnectionAttempts* out) const override {
169 out->clear();
171 void ClearConnectionAttempts() override {}
172 void AddConnectionAttempts(const net::ConnectionAttempts& attempts) override {
175 private:
176 MuxChannel* channel_;
178 int base_channel_error_ = net::OK;
180 net::CompletionCallback read_callback_;
181 scoped_refptr<net::IOBuffer> read_buffer_;
182 int read_buffer_size_;
184 bool write_pending_;
185 int write_result_;
186 net::CompletionCallback write_callback_;
188 net::BoundNetLog net_log_;
190 DISALLOW_COPY_AND_ASSIGN(MuxSocket);
194 ChannelMultiplexer::MuxChannel::MuxChannel(
195 ChannelMultiplexer* multiplexer,
196 const std::string& name,
197 int send_id)
198 : multiplexer_(multiplexer),
199 name_(name),
200 send_id_(send_id),
201 id_sent_(false),
202 receive_id_(kChannelIdUnknown),
203 socket_(nullptr) {
206 ChannelMultiplexer::MuxChannel::~MuxChannel() {
207 // Socket must be destroyed before the channel.
208 DCHECK(!socket_);
209 STLDeleteElements(&pending_packets_);
212 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
213 DCHECK(!socket_); // Can't create more than one socket per channel.
214 scoped_ptr<MuxSocket> result(new MuxSocket(this));
215 socket_ = result.get();
216 return result.Pass();
219 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
220 scoped_ptr<MultiplexPacket> packet,
221 const base::Closure& done_task) {
222 DCHECK_EQ(packet->channel_id(), receive_id_);
223 if (packet->data().size() > 0) {
224 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
225 if (socket_) {
226 // Notify the socket that we have more data.
227 socket_->OnPacketReceived();
232 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
233 if (socket_)
234 socket_->OnBaseChannelError(error);
237 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
238 DCHECK(socket_);
239 socket_ = nullptr;
242 bool ChannelMultiplexer::MuxChannel::DoWrite(
243 scoped_ptr<MultiplexPacket> packet,
244 const base::Closure& done_task) {
245 packet->set_channel_id(send_id_);
246 if (!id_sent_) {
247 packet->set_channel_name(name_);
248 id_sent_ = true;
250 return multiplexer_->DoWrite(packet.Pass(), done_task);
253 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
254 int buffer_len) {
255 int pos = 0;
256 while (buffer_len > 0 && !pending_packets_.empty()) {
257 DCHECK(!pending_packets_.front()->is_empty());
258 int result = pending_packets_.front()->Read(
259 buffer->data() + pos, buffer_len);
260 DCHECK_LE(result, buffer_len);
261 pos += result;
262 buffer_len -= pos;
263 if (pending_packets_.front()->is_empty()) {
264 delete pending_packets_.front();
265 pending_packets_.erase(pending_packets_.begin());
268 return pos;
271 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
272 : channel_(channel),
273 read_buffer_size_(0),
274 write_pending_(false),
275 write_result_(0) {
278 ChannelMultiplexer::MuxSocket::~MuxSocket() {
279 channel_->OnSocketDestroyed();
282 int ChannelMultiplexer::MuxSocket::Read(
283 net::IOBuffer* buffer, int buffer_len,
284 const net::CompletionCallback& callback) {
285 DCHECK(CalledOnValidThread());
286 DCHECK(read_callback_.is_null());
288 if (base_channel_error_ != net::OK)
289 return base_channel_error_;
291 int result = channel_->DoRead(buffer, buffer_len);
292 if (result == 0) {
293 read_buffer_ = buffer;
294 read_buffer_size_ = buffer_len;
295 read_callback_ = callback;
296 return net::ERR_IO_PENDING;
298 return result;
301 int ChannelMultiplexer::MuxSocket::Write(
302 net::IOBuffer* buffer, int buffer_len,
303 const net::CompletionCallback& callback) {
304 DCHECK(CalledOnValidThread());
305 DCHECK(write_callback_.is_null());
307 if (base_channel_error_ != net::OK)
308 return base_channel_error_;
310 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
311 size_t size = std::min(kMaxPacketSize, buffer_len);
312 packet->mutable_data()->assign(buffer->data(), size);
314 write_pending_ = true;
315 bool result = channel_->DoWrite(packet.Pass(), base::Bind(
316 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
318 if (!result) {
319 // Cannot complete the write, e.g. if the connection has been terminated.
320 return net::ERR_FAILED;
323 // OnWriteComplete() might be called above synchronously.
324 if (write_pending_) {
325 DCHECK(write_callback_.is_null());
326 write_callback_ = callback;
327 write_result_ = size;
328 return net::ERR_IO_PENDING;
331 return size;
334 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
335 write_pending_ = false;
336 if (!write_callback_.is_null())
337 base::ResetAndReturn(&write_callback_).Run(write_result_);
341 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
342 base_channel_error_ = error;
344 // Here only one of the read and write callbacks is called if both of them are
345 // pending. Ideally both of them should be called in that case, but that would
346 // require the second one to be called asynchronously which would complicate
347 // this code. Channels handle read and write errors the same way (see
348 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
349 // callbacks is enough.
351 if (!read_callback_.is_null()) {
352 base::ResetAndReturn(&read_callback_).Run(error);
353 return;
356 if (!write_callback_.is_null())
357 base::ResetAndReturn(&write_callback_).Run(error);
360 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
361 if (!read_callback_.is_null()) {
362 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
363 read_buffer_ = nullptr;
364 DCHECK_GT(result, 0);
365 base::ResetAndReturn(&read_callback_).Run(result);
369 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
370 const std::string& base_channel_name)
371 : base_channel_factory_(factory),
372 base_channel_name_(base_channel_name),
373 next_channel_id_(0),
374 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket,
375 base::Unretained(this)),
376 &reader_),
377 weak_factory_(this) {
380 ChannelMultiplexer::~ChannelMultiplexer() {
381 DCHECK(pending_channels_.empty());
382 STLDeleteValues(&channels_);
384 // Cancel creation of the base channel if it hasn't finished.
385 if (base_channel_factory_)
386 base_channel_factory_->CancelChannelCreation(base_channel_name_);
389 void ChannelMultiplexer::CreateChannel(const std::string& name,
390 const ChannelCreatedCallback& callback) {
391 if (base_channel_.get()) {
392 // Already have |base_channel_|. Create new multiplexed channel
393 // synchronously.
394 callback.Run(GetOrCreateChannel(name)->CreateSocket());
395 } else if (!base_channel_.get() && !base_channel_factory_) {
396 // Fail synchronously if we failed to create |base_channel_|.
397 callback.Run(nullptr);
398 } else {
399 // Still waiting for the |base_channel_|.
400 pending_channels_.push_back(PendingChannel(name, callback));
402 // If this is the first multiplexed channel then create the base channel.
403 if (pending_channels_.size() == 1U) {
404 base_channel_factory_->CreateChannel(
405 base_channel_name_,
406 base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
407 base::Unretained(this)));
412 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
413 for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
414 it != pending_channels_.end(); ++it) {
415 if (it->name == name) {
416 pending_channels_.erase(it);
417 return;
422 void ChannelMultiplexer::OnBaseChannelReady(
423 scoped_ptr<net::StreamSocket> socket) {
424 base_channel_factory_ = nullptr;
425 base_channel_ = socket.Pass();
427 if (base_channel_.get()) {
428 // Initialize reader and writer.
429 reader_.StartReading(base_channel_.get(),
430 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
431 base::Unretained(this)));
432 writer_.Init(base_channel_.get(),
433 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
434 base::Unretained(this)));
437 DoCreatePendingChannels();
440 void ChannelMultiplexer::DoCreatePendingChannels() {
441 if (pending_channels_.empty())
442 return;
444 // Every time this function is called it connects a single channel and posts a
445 // separate task to connect other channels. This is necessary because the
446 // callback may destroy the multiplexer or somehow else modify
447 // |pending_channels_| list (e.g. call CancelChannelCreation()).
448 base::ThreadTaskRunnerHandle::Get()->PostTask(
449 FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
450 weak_factory_.GetWeakPtr()));
452 PendingChannel c = pending_channels_.front();
453 pending_channels_.erase(pending_channels_.begin());
454 scoped_ptr<net::StreamSocket> socket;
455 if (base_channel_.get())
456 socket = GetOrCreateChannel(c.name)->CreateSocket();
457 c.callback.Run(socket.Pass());
460 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
461 const std::string& name) {
462 // Check if we already have a channel with the requested name.
463 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
464 if (it != channels_.end())
465 return it->second;
467 // Create a new channel if we haven't found existing one.
468 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
469 ++next_channel_id_;
470 channels_[channel->name()] = channel;
471 return channel;
475 void ChannelMultiplexer::OnBaseChannelError(int error) {
476 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
477 it != channels_.end(); ++it) {
478 base::ThreadTaskRunnerHandle::Get()->PostTask(
479 FROM_HERE,
480 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
481 weak_factory_.GetWeakPtr(), it->second->name(), error));
485 void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
486 int error) {
487 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
488 if (it != channels_.end())
489 it->second->OnBaseChannelError(error);
492 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
493 const base::Closure& done_task) {
494 DCHECK(packet->has_channel_id());
495 if (!packet->has_channel_id()) {
496 LOG(ERROR) << "Received packet without channel_id.";
497 done_task.Run();
498 return;
501 int receive_id = packet->channel_id();
502 MuxChannel* channel = nullptr;
503 std::map<int, MuxChannel*>::iterator it =
504 channels_by_receive_id_.find(receive_id);
505 if (it != channels_by_receive_id_.end()) {
506 channel = it->second;
507 } else {
508 // This is a new |channel_id| we haven't seen before. Look it up by name.
509 if (!packet->has_channel_name()) {
510 LOG(ERROR) << "Received packet with unknown channel_id and "
511 "without channel_name.";
512 done_task.Run();
513 return;
515 channel = GetOrCreateChannel(packet->channel_name());
516 channel->set_receive_id(receive_id);
517 channels_by_receive_id_[receive_id] = channel;
520 channel->OnIncomingPacket(packet.Pass(), done_task);
523 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
524 const base::Closure& done_task) {
525 return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
528 } // namespace protocol
529 } // namespace remoting