Updating trunk VERSION from 2139.0 to 2140.0
[chromium-blink-merge.git] / remoting / protocol / channel_multiplexer.cc
blobe8c195a3e7c24f8e0187e230ee2396e074725d74
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/channel_multiplexer.h"
7 #include <string.h>
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/stl_util.h"
14 #include "base/thread_task_runner_handle.h"
15 #include "net/base/net_errors.h"
16 #include "net/socket/stream_socket.h"
17 #include "remoting/protocol/message_serialization.h"
19 namespace remoting {
20 namespace protocol {
22 namespace {
23 const int kChannelIdUnknown = -1;
24 const int kMaxPacketSize = 1024;
26 class PendingPacket {
27 public:
28 PendingPacket(scoped_ptr<MultiplexPacket> packet,
29 const base::Closure& done_task)
30 : packet(packet.Pass()),
31 done_task(done_task),
32 pos(0U) {
34 ~PendingPacket() {
35 done_task.Run();
38 bool is_empty() { return pos >= packet->data().size(); }
40 int Read(char* buffer, size_t size) {
41 size = std::min(size, packet->data().size() - pos);
42 memcpy(buffer, packet->data().data() + pos, size);
43 pos += size;
44 return size;
47 private:
48 scoped_ptr<MultiplexPacket> packet;
49 base::Closure done_task;
50 size_t pos;
52 DISALLOW_COPY_AND_ASSIGN(PendingPacket);
55 } // namespace
57 const char ChannelMultiplexer::kMuxChannelName[] = "mux";
59 struct ChannelMultiplexer::PendingChannel {
60 PendingChannel(const std::string& name,
61 const StreamChannelCallback& callback)
62 : name(name), callback(callback) {
64 std::string name;
65 StreamChannelCallback callback;
68 class ChannelMultiplexer::MuxChannel {
69 public:
70 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
71 int send_id);
72 ~MuxChannel();
74 const std::string& name() { return name_; }
75 int receive_id() { return receive_id_; }
76 void set_receive_id(int id) { receive_id_ = id; }
78 // Called by ChannelMultiplexer.
79 scoped_ptr<net::StreamSocket> CreateSocket();
80 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
81 const base::Closure& done_task);
82 void OnWriteFailed();
84 // Called by MuxSocket.
85 void OnSocketDestroyed();
86 bool DoWrite(scoped_ptr<MultiplexPacket> packet,
87 const base::Closure& done_task);
88 int DoRead(net::IOBuffer* buffer, int buffer_len);
90 private:
91 ChannelMultiplexer* multiplexer_;
92 std::string name_;
93 int send_id_;
94 bool id_sent_;
95 int receive_id_;
96 MuxSocket* socket_;
97 std::list<PendingPacket*> pending_packets_;
99 DISALLOW_COPY_AND_ASSIGN(MuxChannel);
102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
103 public base::NonThreadSafe,
104 public base::SupportsWeakPtr<MuxSocket> {
105 public:
106 MuxSocket(MuxChannel* channel);
107 virtual ~MuxSocket();
109 void OnWriteComplete();
110 void OnWriteFailed();
111 void OnPacketReceived();
113 // net::StreamSocket interface.
114 virtual int Read(net::IOBuffer* buffer, int buffer_len,
115 const net::CompletionCallback& callback) OVERRIDE;
116 virtual int Write(net::IOBuffer* buffer, int buffer_len,
117 const net::CompletionCallback& callback) OVERRIDE;
119 virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
120 NOTIMPLEMENTED();
121 return net::ERR_NOT_IMPLEMENTED;
123 virtual int SetSendBufferSize(int32 size) OVERRIDE {
124 NOTIMPLEMENTED();
125 return net::ERR_NOT_IMPLEMENTED;
128 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
129 NOTIMPLEMENTED();
130 return net::ERR_NOT_IMPLEMENTED;
132 virtual void Disconnect() OVERRIDE {
133 NOTIMPLEMENTED();
135 virtual bool IsConnected() const OVERRIDE {
136 NOTIMPLEMENTED();
137 return true;
139 virtual bool IsConnectedAndIdle() const OVERRIDE {
140 NOTIMPLEMENTED();
141 return false;
143 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
144 NOTIMPLEMENTED();
145 return net::ERR_NOT_IMPLEMENTED;
147 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
148 NOTIMPLEMENTED();
149 return net::ERR_NOT_IMPLEMENTED;
151 virtual const net::BoundNetLog& NetLog() const OVERRIDE {
152 NOTIMPLEMENTED();
153 return net_log_;
155 virtual void SetSubresourceSpeculation() OVERRIDE {
156 NOTIMPLEMENTED();
158 virtual void SetOmniboxSpeculation() OVERRIDE {
159 NOTIMPLEMENTED();
161 virtual bool WasEverUsed() const OVERRIDE {
162 return true;
164 virtual bool UsingTCPFastOpen() const OVERRIDE {
165 return false;
167 virtual bool WasNpnNegotiated() const OVERRIDE {
168 return false;
170 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
171 return net::kProtoUnknown;
173 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
174 NOTIMPLEMENTED();
175 return false;
178 private:
179 MuxChannel* channel_;
181 net::CompletionCallback read_callback_;
182 scoped_refptr<net::IOBuffer> read_buffer_;
183 int read_buffer_size_;
185 bool write_pending_;
186 int write_result_;
187 net::CompletionCallback write_callback_;
189 net::BoundNetLog net_log_;
191 DISALLOW_COPY_AND_ASSIGN(MuxSocket);
195 ChannelMultiplexer::MuxChannel::MuxChannel(
196 ChannelMultiplexer* multiplexer,
197 const std::string& name,
198 int send_id)
199 : multiplexer_(multiplexer),
200 name_(name),
201 send_id_(send_id),
202 id_sent_(false),
203 receive_id_(kChannelIdUnknown),
204 socket_(NULL) {
207 ChannelMultiplexer::MuxChannel::~MuxChannel() {
208 // Socket must be destroyed before the channel.
209 DCHECK(!socket_);
210 STLDeleteElements(&pending_packets_);
213 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
214 DCHECK(!socket_); // Can't create more than one socket per channel.
215 scoped_ptr<MuxSocket> result(new MuxSocket(this));
216 socket_ = result.get();
217 return result.PassAs<net::StreamSocket>();
220 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
221 scoped_ptr<MultiplexPacket> packet,
222 const base::Closure& done_task) {
223 DCHECK_EQ(packet->channel_id(), receive_id_);
224 if (packet->data().size() > 0) {
225 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
226 if (socket_) {
227 // Notify the socket that we have more data.
228 socket_->OnPacketReceived();
233 void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
234 if (socket_)
235 socket_->OnWriteFailed();
238 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
239 DCHECK(socket_);
240 socket_ = NULL;
243 bool ChannelMultiplexer::MuxChannel::DoWrite(
244 scoped_ptr<MultiplexPacket> packet,
245 const base::Closure& done_task) {
246 packet->set_channel_id(send_id_);
247 if (!id_sent_) {
248 packet->set_channel_name(name_);
249 id_sent_ = true;
251 return multiplexer_->DoWrite(packet.Pass(), done_task);
254 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
255 int buffer_len) {
256 int pos = 0;
257 while (buffer_len > 0 && !pending_packets_.empty()) {
258 DCHECK(!pending_packets_.front()->is_empty());
259 int result = pending_packets_.front()->Read(
260 buffer->data() + pos, buffer_len);
261 DCHECK_LE(result, buffer_len);
262 pos += result;
263 buffer_len -= pos;
264 if (pending_packets_.front()->is_empty()) {
265 delete pending_packets_.front();
266 pending_packets_.erase(pending_packets_.begin());
269 return pos;
272 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
273 : channel_(channel),
274 read_buffer_size_(0),
275 write_pending_(false),
276 write_result_(0) {
279 ChannelMultiplexer::MuxSocket::~MuxSocket() {
280 channel_->OnSocketDestroyed();
283 int ChannelMultiplexer::MuxSocket::Read(
284 net::IOBuffer* buffer, int buffer_len,
285 const net::CompletionCallback& callback) {
286 DCHECK(CalledOnValidThread());
287 DCHECK(read_callback_.is_null());
289 int result = channel_->DoRead(buffer, buffer_len);
290 if (result == 0) {
291 read_buffer_ = buffer;
292 read_buffer_size_ = buffer_len;
293 read_callback_ = callback;
294 return net::ERR_IO_PENDING;
296 return result;
299 int ChannelMultiplexer::MuxSocket::Write(
300 net::IOBuffer* buffer, int buffer_len,
301 const net::CompletionCallback& callback) {
302 DCHECK(CalledOnValidThread());
304 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
305 size_t size = std::min(kMaxPacketSize, buffer_len);
306 packet->mutable_data()->assign(buffer->data(), size);
308 write_pending_ = true;
309 bool result = channel_->DoWrite(packet.Pass(), base::Bind(
310 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
312 if (!result) {
313 // Cannot complete the write, e.g. if the connection has been terminated.
314 return net::ERR_FAILED;
317 // OnWriteComplete() might be called above synchronously.
318 if (write_pending_) {
319 DCHECK(write_callback_.is_null());
320 write_callback_ = callback;
321 write_result_ = size;
322 return net::ERR_IO_PENDING;
325 return size;
328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
329 write_pending_ = false;
330 if (!write_callback_.is_null()) {
331 net::CompletionCallback cb;
332 std::swap(cb, write_callback_);
333 cb.Run(write_result_);
337 void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
338 if (!write_callback_.is_null()) {
339 net::CompletionCallback cb;
340 std::swap(cb, write_callback_);
341 cb.Run(net::ERR_FAILED);
345 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
346 if (!read_callback_.is_null()) {
347 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
348 read_buffer_ = NULL;
349 DCHECK_GT(result, 0);
350 net::CompletionCallback cb;
351 std::swap(cb, read_callback_);
352 cb.Run(result);
356 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
357 const std::string& base_channel_name)
358 : base_channel_factory_(factory),
359 base_channel_name_(base_channel_name),
360 next_channel_id_(0),
361 weak_factory_(this) {
364 ChannelMultiplexer::~ChannelMultiplexer() {
365 DCHECK(pending_channels_.empty());
366 STLDeleteValues(&channels_);
368 // Cancel creation of the base channel if it hasn't finished.
369 if (base_channel_factory_)
370 base_channel_factory_->CancelChannelCreation(base_channel_name_);
373 void ChannelMultiplexer::CreateStreamChannel(
374 const std::string& name,
375 const StreamChannelCallback& callback) {
376 if (base_channel_.get()) {
377 // Already have |base_channel_|. Create new multiplexed channel
378 // synchronously.
379 callback.Run(GetOrCreateChannel(name)->CreateSocket());
380 } else if (!base_channel_.get() && !base_channel_factory_) {
381 // Fail synchronously if we failed to create |base_channel_|.
382 callback.Run(scoped_ptr<net::StreamSocket>());
383 } else {
384 // Still waiting for the |base_channel_|.
385 pending_channels_.push_back(PendingChannel(name, callback));
387 // If this is the first multiplexed channel then create the base channel.
388 if (pending_channels_.size() == 1U) {
389 base_channel_factory_->CreateStreamChannel(
390 base_channel_name_,
391 base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
392 base::Unretained(this)));
397 void ChannelMultiplexer::CreateDatagramChannel(
398 const std::string& name,
399 const DatagramChannelCallback& callback) {
400 NOTIMPLEMENTED();
401 callback.Run(scoped_ptr<net::Socket>());
404 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
405 for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
406 it != pending_channels_.end(); ++it) {
407 if (it->name == name) {
408 pending_channels_.erase(it);
409 return;
414 void ChannelMultiplexer::OnBaseChannelReady(
415 scoped_ptr<net::StreamSocket> socket) {
416 base_channel_factory_ = NULL;
417 base_channel_ = socket.Pass();
419 if (base_channel_.get()) {
420 // Initialize reader and writer.
421 reader_.Init(base_channel_.get(),
422 base::Bind(&ChannelMultiplexer::OnIncomingPacket,
423 base::Unretained(this)));
424 writer_.Init(base_channel_.get(),
425 base::Bind(&ChannelMultiplexer::OnWriteFailed,
426 base::Unretained(this)));
429 DoCreatePendingChannels();
432 void ChannelMultiplexer::DoCreatePendingChannels() {
433 if (pending_channels_.empty())
434 return;
436 // Every time this function is called it connects a single channel and posts a
437 // separate task to connect other channels. This is necessary because the
438 // callback may destroy the multiplexer or somehow else modify
439 // |pending_channels_| list (e.g. call CancelChannelCreation()).
440 base::ThreadTaskRunnerHandle::Get()->PostTask(
441 FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
442 weak_factory_.GetWeakPtr()));
444 PendingChannel c = pending_channels_.front();
445 pending_channels_.erase(pending_channels_.begin());
446 scoped_ptr<net::StreamSocket> socket;
447 if (base_channel_.get())
448 socket = GetOrCreateChannel(c.name)->CreateSocket();
449 c.callback.Run(socket.Pass());
452 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
453 const std::string& name) {
454 // Check if we already have a channel with the requested name.
455 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
456 if (it != channels_.end())
457 return it->second;
459 // Create a new channel if we haven't found existing one.
460 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
461 ++next_channel_id_;
462 channels_[channel->name()] = channel;
463 return channel;
467 void ChannelMultiplexer::OnWriteFailed(int error) {
468 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
469 it != channels_.end(); ++it) {
470 base::ThreadTaskRunnerHandle::Get()->PostTask(
471 FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
472 weak_factory_.GetWeakPtr(), it->second->name()));
476 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
477 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
478 if (it != channels_.end()) {
479 it->second->OnWriteFailed();
483 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
484 const base::Closure& done_task) {
485 DCHECK(packet->has_channel_id());
486 if (!packet->has_channel_id()) {
487 LOG(ERROR) << "Received packet without channel_id.";
488 done_task.Run();
489 return;
492 int receive_id = packet->channel_id();
493 MuxChannel* channel = NULL;
494 std::map<int, MuxChannel*>::iterator it =
495 channels_by_receive_id_.find(receive_id);
496 if (it != channels_by_receive_id_.end()) {
497 channel = it->second;
498 } else {
499 // This is a new |channel_id| we haven't seen before. Look it up by name.
500 if (!packet->has_channel_name()) {
501 LOG(ERROR) << "Received packet with unknown channel_id and "
502 "without channel_name.";
503 done_task.Run();
504 return;
506 channel = GetOrCreateChannel(packet->channel_name());
507 channel->set_receive_id(receive_id);
508 channels_by_receive_id_[receive_id] = channel;
511 channel->OnIncomingPacket(packet.Pass(), done_task);
514 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
515 const base::Closure& done_task) {
516 return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
519 } // namespace protocol
520 } // namespace remoting