Re-land: C++ readability review
[chromium-blink-merge.git] / remoting / protocol / channel_multiplexer.cc
blob8d885fe28f05765d7c697f01138a0f86688f2e2f
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
7 #include <string.h>
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/stl_util.h"
14 #include "base/thread_task_runner_handle.h"
15 #include "net/base/net_errors.h"
16 #include "net/socket/stream_socket.h"
17 #include "remoting/protocol/message_serialization.h"
19 namespace remoting {
20 namespace protocol {
22 namespace {
23 const int kChannelIdUnknown = -1;
24 const int kMaxPacketSize = 1024;
26 class PendingPacket {
27 public:
28 PendingPacket(scoped_ptr<MultiplexPacket> packet,
29 const base::Closure& done_task)
30 : packet(packet.Pass()),
31 done_task(done_task),
32 pos(0U) {
34 ~PendingPacket() {
35 done_task.Run();
38 bool is_empty() { return pos >= packet->data().size(); }
40 int Read(char* buffer, size_t size) {
41 size = std::min(size, packet->data().size() - pos);
42 memcpy(buffer, packet->data().data() + pos, size);
43 pos += size;
44 return size;
47 private:
48 scoped_ptr<MultiplexPacket> packet;
49 base::Closure done_task;
50 size_t pos;
52 DISALLOW_COPY_AND_ASSIGN(PendingPacket);
55 } // namespace
57 const char ChannelMultiplexer::kMuxChannelName[] = "mux";
59 struct ChannelMultiplexer::PendingChannel {
60 PendingChannel(const std::string& name,
61 const ChannelCreatedCallback& callback)
62 : name(name), callback(callback) {
64 std::string name;
65 ChannelCreatedCallback callback;
68 class ChannelMultiplexer::MuxChannel {
69 public:
70 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
71 int send_id);
72 ~MuxChannel();
74 const std::string& name() { return name_; }
75 int receive_id() { return receive_id_; }
76 void set_receive_id(int id) { receive_id_ = id; }
78 // Called by ChannelMultiplexer.
79 scoped_ptr<net::StreamSocket> CreateSocket();
80 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
81 const base::Closure& done_task);
82 void OnWriteFailed();
84 // Called by MuxSocket.
85 void OnSocketDestroyed();
86 bool DoWrite(scoped_ptr<MultiplexPacket> packet,
87 const base::Closure& done_task);
88 int DoRead(net::IOBuffer* buffer, int buffer_len);
90 private:
91 ChannelMultiplexer* multiplexer_;
92 std::string name_;
93 int send_id_;
94 bool id_sent_;
95 int receive_id_;
96 MuxSocket* socket_;
97 std::list<PendingPacket*> pending_packets_;
99 DISALLOW_COPY_AND_ASSIGN(MuxChannel);
102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
103 public base::NonThreadSafe,
104 public base::SupportsWeakPtr<MuxSocket> {
105 public:
106 MuxSocket(MuxChannel* channel);
107 ~MuxSocket() override;
109 void OnWriteComplete();
110 void OnWriteFailed();
111 void OnPacketReceived();
113 // net::StreamSocket interface.
114 int Read(net::IOBuffer* buffer,
115 int buffer_len,
116 const net::CompletionCallback& callback) override;
117 int Write(net::IOBuffer* buffer,
118 int buffer_len,
119 const net::CompletionCallback& callback) override;
121 int SetReceiveBufferSize(int32 size) override {
122 NOTIMPLEMENTED();
123 return net::ERR_NOT_IMPLEMENTED;
125 int SetSendBufferSize(int32 size) override {
126 NOTIMPLEMENTED();
127 return net::ERR_NOT_IMPLEMENTED;
130 int Connect(const net::CompletionCallback& callback) override {
131 NOTIMPLEMENTED();
132 return net::ERR_NOT_IMPLEMENTED;
134 void Disconnect() override { NOTIMPLEMENTED(); }
135 bool IsConnected() const override {
136 NOTIMPLEMENTED();
137 return true;
139 bool IsConnectedAndIdle() const override {
140 NOTIMPLEMENTED();
141 return false;
143 int GetPeerAddress(net::IPEndPoint* address) const override {
144 NOTIMPLEMENTED();
145 return net::ERR_NOT_IMPLEMENTED;
147 int GetLocalAddress(net::IPEndPoint* address) const override {
148 NOTIMPLEMENTED();
149 return net::ERR_NOT_IMPLEMENTED;
151 const net::BoundNetLog& NetLog() const override {
152 NOTIMPLEMENTED();
153 return net_log_;
155 void SetSubresourceSpeculation() override { NOTIMPLEMENTED(); }
156 void SetOmniboxSpeculation() override { NOTIMPLEMENTED(); }
157 bool WasEverUsed() const override { return true; }
158 bool UsingTCPFastOpen() const override { return false; }
159 bool WasNpnNegotiated() const override { return false; }
160 net::NextProto GetNegotiatedProtocol() const override {
161 return net::kProtoUnknown;
163 bool GetSSLInfo(net::SSLInfo* ssl_info) override {
164 NOTIMPLEMENTED();
165 return false;
168 private:
169 MuxChannel* channel_;
171 net::CompletionCallback read_callback_;
172 scoped_refptr<net::IOBuffer> read_buffer_;
173 int read_buffer_size_;
175 bool write_pending_;
176 int write_result_;
177 net::CompletionCallback write_callback_;
179 net::BoundNetLog net_log_;
181 DISALLOW_COPY_AND_ASSIGN(MuxSocket);
185 ChannelMultiplexer::MuxChannel::MuxChannel(
186 ChannelMultiplexer* multiplexer,
187 const std::string& name,
188 int send_id)
189 : multiplexer_(multiplexer),
190 name_(name),
191 send_id_(send_id),
192 id_sent_(false),
193 receive_id_(kChannelIdUnknown),
194 socket_(nullptr) {
197 ChannelMultiplexer::MuxChannel::~MuxChannel() {
198 // Socket must be destroyed before the channel.
199 DCHECK(!socket_);
200 STLDeleteElements(&pending_packets_);
203 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
204 DCHECK(!socket_); // Can't create more than one socket per channel.
205 scoped_ptr<MuxSocket> result(new MuxSocket(this));
206 socket_ = result.get();
207 return result.Pass();
210 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
211 scoped_ptr<MultiplexPacket> packet,
212 const base::Closure& done_task) {
213 DCHECK_EQ(packet->channel_id(), receive_id_);
214 if (packet->data().size() > 0) {
215 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
216 if (socket_) {
217 // Notify the socket that we have more data.
218 socket_->OnPacketReceived();
223 void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
224 if (socket_)
225 socket_->OnWriteFailed();
228 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
229 DCHECK(socket_);
230 socket_ = nullptr;
233 bool ChannelMultiplexer::MuxChannel::DoWrite(
234 scoped_ptr<MultiplexPacket> packet,
235 const base::Closure& done_task) {
236 packet->set_channel_id(send_id_);
237 if (!id_sent_) {
238 packet->set_channel_name(name_);
239 id_sent_ = true;
241 return multiplexer_->DoWrite(packet.Pass(), done_task);
244 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
245 int buffer_len) {
246 int pos = 0;
247 while (buffer_len > 0 && !pending_packets_.empty()) {
248 DCHECK(!pending_packets_.front()->is_empty());
249 int result = pending_packets_.front()->Read(
250 buffer->data() + pos, buffer_len);
251 DCHECK_LE(result, buffer_len);
252 pos += result;
253 buffer_len -= pos;
254 if (pending_packets_.front()->is_empty()) {
255 delete pending_packets_.front();
256 pending_packets_.erase(pending_packets_.begin());
259 return pos;
262 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
263 : channel_(channel),
264 read_buffer_size_(0),
265 write_pending_(false),
266 write_result_(0) {
269 ChannelMultiplexer::MuxSocket::~MuxSocket() {
270 channel_->OnSocketDestroyed();
273 int ChannelMultiplexer::MuxSocket::Read(
274 net::IOBuffer* buffer, int buffer_len,
275 const net::CompletionCallback& callback) {
276 DCHECK(CalledOnValidThread());
277 DCHECK(read_callback_.is_null());
279 int result = channel_->DoRead(buffer, buffer_len);
280 if (result == 0) {
281 read_buffer_ = buffer;
282 read_buffer_size_ = buffer_len;
283 read_callback_ = callback;
284 return net::ERR_IO_PENDING;
286 return result;
289 int ChannelMultiplexer::MuxSocket::Write(
290 net::IOBuffer* buffer, int buffer_len,
291 const net::CompletionCallback& callback) {
292 DCHECK(CalledOnValidThread());
294 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
295 size_t size = std::min(kMaxPacketSize, buffer_len);
296 packet->mutable_data()->assign(buffer->data(), size);
298 write_pending_ = true;
299 bool result = channel_->DoWrite(packet.Pass(), base::Bind(
300 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
302 if (!result) {
303 // Cannot complete the write, e.g. if the connection has been terminated.
304 return net::ERR_FAILED;
307 // OnWriteComplete() might be called above synchronously.
308 if (write_pending_) {
309 DCHECK(write_callback_.is_null());
310 write_callback_ = callback;
311 write_result_ = size;
312 return net::ERR_IO_PENDING;
315 return size;
318 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
319 write_pending_ = false;
320 if (!write_callback_.is_null()) {
321 net::CompletionCallback cb;
322 std::swap(cb, write_callback_);
323 cb.Run(write_result_);
327 void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
328 if (!write_callback_.is_null()) {
329 net::CompletionCallback cb;
330 std::swap(cb, write_callback_);
331 cb.Run(net::ERR_FAILED);
335 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
336 if (!read_callback_.is_null()) {
337 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
338 read_buffer_ = nullptr;
339 DCHECK_GT(result, 0);
340 net::CompletionCallback cb;
341 std::swap(cb, read_callback_);
342 cb.Run(result);
346 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
347 const std::string& base_channel_name)
348 : base_channel_factory_(factory),
349 base_channel_name_(base_channel_name),
350 next_channel_id_(0),
351 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket,
352 base::Unretained(this)),
353 &reader_),
354 weak_factory_(this) {
357 ChannelMultiplexer::~ChannelMultiplexer() {
358 DCHECK(pending_channels_.empty());
359 STLDeleteValues(&channels_);
361 // Cancel creation of the base channel if it hasn't finished.
362 if (base_channel_factory_)
363 base_channel_factory_->CancelChannelCreation(base_channel_name_);
366 void ChannelMultiplexer::CreateChannel(const std::string& name,
367 const ChannelCreatedCallback& callback) {
368 if (base_channel_.get()) {
369 // Already have |base_channel_|. Create new multiplexed channel
370 // synchronously.
371 callback.Run(GetOrCreateChannel(name)->CreateSocket());
372 } else if (!base_channel_.get() && !base_channel_factory_) {
373 // Fail synchronously if we failed to create |base_channel_|.
374 callback.Run(nullptr);
375 } else {
376 // Still waiting for the |base_channel_|.
377 pending_channels_.push_back(PendingChannel(name, callback));
379 // If this is the first multiplexed channel then create the base channel.
380 if (pending_channels_.size() == 1U) {
381 base_channel_factory_->CreateChannel(
382 base_channel_name_,
383 base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
384 base::Unretained(this)));
389 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
390 for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
391 it != pending_channels_.end(); ++it) {
392 if (it->name == name) {
393 pending_channels_.erase(it);
394 return;
399 void ChannelMultiplexer::OnBaseChannelReady(
400 scoped_ptr<net::StreamSocket> socket) {
401 base_channel_factory_ = nullptr;
402 base_channel_ = socket.Pass();
404 if (base_channel_.get()) {
405 // Initialize reader and writer.
406 reader_.StartReading(base_channel_.get());
407 writer_.Init(base_channel_.get(),
408 base::Bind(&ChannelMultiplexer::OnWriteFailed,
409 base::Unretained(this)));
412 DoCreatePendingChannels();
415 void ChannelMultiplexer::DoCreatePendingChannels() {
416 if (pending_channels_.empty())
417 return;
419 // Every time this function is called it connects a single channel and posts a
420 // separate task to connect other channels. This is necessary because the
421 // callback may destroy the multiplexer or somehow else modify
422 // |pending_channels_| list (e.g. call CancelChannelCreation()).
423 base::ThreadTaskRunnerHandle::Get()->PostTask(
424 FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
425 weak_factory_.GetWeakPtr()));
427 PendingChannel c = pending_channels_.front();
428 pending_channels_.erase(pending_channels_.begin());
429 scoped_ptr<net::StreamSocket> socket;
430 if (base_channel_.get())
431 socket = GetOrCreateChannel(c.name)->CreateSocket();
432 c.callback.Run(socket.Pass());
435 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
436 const std::string& name) {
437 // Check if we already have a channel with the requested name.
438 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
439 if (it != channels_.end())
440 return it->second;
442 // Create a new channel if we haven't found existing one.
443 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
444 ++next_channel_id_;
445 channels_[channel->name()] = channel;
446 return channel;
450 void ChannelMultiplexer::OnWriteFailed(int error) {
451 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
452 it != channels_.end(); ++it) {
453 base::ThreadTaskRunnerHandle::Get()->PostTask(
454 FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
455 weak_factory_.GetWeakPtr(), it->second->name()));
459 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
460 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
461 if (it != channels_.end()) {
462 it->second->OnWriteFailed();
466 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
467 const base::Closure& done_task) {
468 DCHECK(packet->has_channel_id());
469 if (!packet->has_channel_id()) {
470 LOG(ERROR) << "Received packet without channel_id.";
471 done_task.Run();
472 return;
475 int receive_id = packet->channel_id();
476 MuxChannel* channel = nullptr;
477 std::map<int, MuxChannel*>::iterator it =
478 channels_by_receive_id_.find(receive_id);
479 if (it != channels_by_receive_id_.end()) {
480 channel = it->second;
481 } else {
482 // This is a new |channel_id| we haven't seen before. Look it up by name.
483 if (!packet->has_channel_name()) {
484 LOG(ERROR) << "Received packet with unknown channel_id and "
485 "without channel_name.";
486 done_task.Run();
487 return;
489 channel = GetOrCreateChannel(packet->channel_name());
490 channel->set_receive_id(receive_id);
491 channels_by_receive_id_[receive_id] = channel;
494 channel->OnIncomingPacket(packet.Pass(), done_task);
497 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
498 const base::Closure& done_task) {
499 return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
502 } // namespace protocol
503 } // namespace remoting