Roll src/third_party/WebKit a3b4a2e:7441784 (svn 202551:202552)
[chromium-blink-merge.git] / extensions / browser / api / cast_channel / cast_transport.cc
blob0ce1db01e549656b479d34480d262c614287495d
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "extensions/browser/api/cast_channel/cast_transport.h"
7 #include <string>
9 #include "base/bind.h"
10 #include "base/format_macros.h"
11 #include "base/message_loop/message_loop.h"
12 #include "base/numerics/safe_conversions.h"
13 #include "base/strings/stringprintf.h"
14 #include "extensions/browser/api/cast_channel/cast_framer.h"
15 #include "extensions/browser/api/cast_channel/cast_message_util.h"
16 #include "extensions/browser/api/cast_channel/logger.h"
17 #include "extensions/browser/api/cast_channel/logger_util.h"
18 #include "extensions/common/api/cast_channel/cast_channel.pb.h"
19 #include "net/base/net_errors.h"
20 #include "net/socket/socket.h"
// Emits a VLOG line prefixed with the peer endpoint and the channel auth
// type. The expansion refers to |ip_endpoint_| and |channel_auth_| directly,
// so the macro is only usable where those names are in scope.
22 #define VLOG_WITH_CONNECTION(level) \
23 VLOG(level) << "[" << ip_endpoint_.ToString() << ", auth=" << channel_auth_ \
24 << "] "
26 namespace extensions {
27 namespace api {
28 namespace cast_channel {
30 CastTransportImpl::CastTransportImpl(net::Socket* socket,
31 int channel_id,
32 const net::IPEndPoint& ip_endpoint,
33 ChannelAuthType channel_auth,
34 scoped_refptr<Logger> logger)
35 : started_(false),
36 socket_(socket),
37 write_state_(WRITE_STATE_IDLE),
38 read_state_(READ_STATE_READ),
39 error_state_(CHANNEL_ERROR_NONE),
40 channel_id_(channel_id),
41 ip_endpoint_(ip_endpoint),
42 channel_auth_(channel_auth),
43 logger_(logger) {
44 DCHECK(socket);
46 // Buffer is reused across messages to minimize unnecessary buffer
47 // [re]allocations.
48 read_buffer_ = new net::GrowableIOBuffer();
49 read_buffer_->SetCapacity(MessageFramer::MessageHeader::max_message_size());
50 framer_.reset(new MessageFramer(read_buffer_));
53 CastTransportImpl::~CastTransportImpl() {
54 DCHECK(CalledOnValidThread());
55 FlushWriteQueue();
58 bool CastTransportImpl::IsTerminalWriteState(
59 CastTransportImpl::WriteState write_state) {
60 return write_state == WRITE_STATE_ERROR || write_state == WRITE_STATE_IDLE;
63 bool CastTransportImpl::IsTerminalReadState(
64 CastTransportImpl::ReadState read_state) {
65 return read_state == READ_STATE_ERROR;
68 // static
69 proto::ReadState CastTransportImpl::ReadStateToProto(
70 CastTransportImpl::ReadState state) {
71 switch (state) {
72 case CastTransportImpl::READ_STATE_UNKNOWN:
73 return proto::READ_STATE_UNKNOWN;
74 case CastTransportImpl::READ_STATE_READ:
75 return proto::READ_STATE_READ;
76 case CastTransportImpl::READ_STATE_READ_COMPLETE:
77 return proto::READ_STATE_READ_COMPLETE;
78 case CastTransportImpl::READ_STATE_DO_CALLBACK:
79 return proto::READ_STATE_DO_CALLBACK;
80 case CastTransportImpl::READ_STATE_HANDLE_ERROR:
81 return proto::READ_STATE_HANDLE_ERROR;
82 case CastTransportImpl::READ_STATE_ERROR:
83 return proto::READ_STATE_ERROR;
84 default:
85 NOTREACHED();
86 return proto::READ_STATE_UNKNOWN;
90 // static
91 proto::WriteState CastTransportImpl::WriteStateToProto(
92 CastTransportImpl::WriteState state) {
93 switch (state) {
94 case CastTransportImpl::WRITE_STATE_IDLE:
95 return proto::WRITE_STATE_IDLE;
96 case CastTransportImpl::WRITE_STATE_UNKNOWN:
97 return proto::WRITE_STATE_UNKNOWN;
98 case CastTransportImpl::WRITE_STATE_WRITE:
99 return proto::WRITE_STATE_WRITE;
100 case CastTransportImpl::WRITE_STATE_WRITE_COMPLETE:
101 return proto::WRITE_STATE_WRITE_COMPLETE;
102 case CastTransportImpl::WRITE_STATE_DO_CALLBACK:
103 return proto::WRITE_STATE_DO_CALLBACK;
104 case CastTransportImpl::WRITE_STATE_HANDLE_ERROR:
105 return proto::WRITE_STATE_HANDLE_ERROR;
106 case CastTransportImpl::WRITE_STATE_ERROR:
107 return proto::WRITE_STATE_ERROR;
108 default:
109 NOTREACHED();
110 return proto::WRITE_STATE_UNKNOWN;
114 // static
115 proto::ErrorState CastTransportImpl::ErrorStateToProto(ChannelError state) {
116 switch (state) {
117 case CHANNEL_ERROR_NONE:
118 return proto::CHANNEL_ERROR_NONE;
119 case CHANNEL_ERROR_CHANNEL_NOT_OPEN:
120 return proto::CHANNEL_ERROR_CHANNEL_NOT_OPEN;
121 case CHANNEL_ERROR_AUTHENTICATION_ERROR:
122 return proto::CHANNEL_ERROR_AUTHENTICATION_ERROR;
123 case CHANNEL_ERROR_CONNECT_ERROR:
124 return proto::CHANNEL_ERROR_CONNECT_ERROR;
125 case CHANNEL_ERROR_SOCKET_ERROR:
126 return proto::CHANNEL_ERROR_SOCKET_ERROR;
127 case CHANNEL_ERROR_TRANSPORT_ERROR:
128 return proto::CHANNEL_ERROR_TRANSPORT_ERROR;
129 case CHANNEL_ERROR_INVALID_MESSAGE:
130 return proto::CHANNEL_ERROR_INVALID_MESSAGE;
131 case CHANNEL_ERROR_INVALID_CHANNEL_ID:
132 return proto::CHANNEL_ERROR_INVALID_CHANNEL_ID;
133 case CHANNEL_ERROR_CONNECT_TIMEOUT:
134 return proto::CHANNEL_ERROR_CONNECT_TIMEOUT;
135 case CHANNEL_ERROR_UNKNOWN:
136 return proto::CHANNEL_ERROR_UNKNOWN;
137 default:
138 NOTREACHED();
139 return proto::CHANNEL_ERROR_NONE;
143 void CastTransportImpl::SetReadDelegate(scoped_ptr<Delegate> delegate) {
144 DCHECK(CalledOnValidThread());
145 DCHECK(delegate);
146 delegate_ = delegate.Pass();
147 if (started_) {
148 delegate_->Start();
152 void CastTransportImpl::FlushWriteQueue() {
153 for (; !write_queue_.empty(); write_queue_.pop()) {
154 net::CompletionCallback& callback = write_queue_.front().callback;
155 base::MessageLoop::current()->PostTask(
156 FROM_HERE, base::Bind(callback, net::ERR_FAILED));
157 callback.Reset();
161 void CastTransportImpl::SendMessage(const CastMessage& message,
162 const net::CompletionCallback& callback) {
163 DCHECK(CalledOnValidThread());
164 std::string serialized_message;
165 if (!MessageFramer::Serialize(message, &serialized_message)) {
166 logger_->LogSocketEventForMessage(channel_id_, proto::SEND_MESSAGE_FAILED,
167 message.namespace_(),
168 "Error when serializing message.");
169 base::MessageLoop::current()->PostTask(
170 FROM_HERE, base::Bind(callback, net::ERR_FAILED));
171 return;
173 WriteRequest write_request(
174 message.namespace_(), serialized_message, callback);
176 write_queue_.push(write_request);
177 logger_->LogSocketEventForMessage(
178 channel_id_, proto::MESSAGE_ENQUEUED, message.namespace_(),
179 base::StringPrintf("Queue size: %" PRIuS, write_queue_.size()));
180 if (write_state_ == WRITE_STATE_IDLE) {
181 SetWriteState(WRITE_STATE_WRITE);
182 OnWriteResult(net::OK);
186 CastTransportImpl::WriteRequest::WriteRequest(
187 const std::string& namespace_,
188 const std::string& payload,
189 const net::CompletionCallback& callback)
190 : message_namespace(namespace_), callback(callback) {
191 VLOG(2) << "WriteRequest size: " << payload.size();
192 io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(payload),
193 payload.size());
196 CastTransportImpl::WriteRequest::~WriteRequest() {
199 void CastTransportImpl::SetReadState(ReadState read_state) {
200 if (read_state_ != read_state) {
201 read_state_ = read_state;
202 logger_->LogSocketReadState(channel_id_, ReadStateToProto(read_state_));
206 void CastTransportImpl::SetWriteState(WriteState write_state) {
207 if (write_state_ != write_state) {
208 write_state_ = write_state;
209 logger_->LogSocketWriteState(channel_id_, WriteStateToProto(write_state_));
213 void CastTransportImpl::SetErrorState(ChannelError error_state) {
214 VLOG_WITH_CONNECTION(2) << "SetErrorState: " << error_state;
215 error_state_ = error_state;
218 void CastTransportImpl::OnWriteResult(int result) {
219 DCHECK(CalledOnValidThread());
220 DCHECK_NE(WRITE_STATE_IDLE, write_state_);
221 if (write_queue_.empty()) {
222 SetWriteState(WRITE_STATE_IDLE);
223 return;
226 // Network operations can either finish synchronously or asynchronously.
227 // This method executes the state machine transitions in a loop so that
228 // write state transitions happen even when network operations finish
229 // synchronously.
230 int rv = result;
231 do {
232 VLOG_WITH_CONNECTION(2) << "OnWriteResult (state=" << write_state_ << ", "
233 << "result=" << rv << ", "
234 << "queue size=" << write_queue_.size() << ")";
236 WriteState state = write_state_;
237 write_state_ = WRITE_STATE_UNKNOWN;
238 switch (state) {
239 case WRITE_STATE_WRITE:
240 rv = DoWrite();
241 break;
242 case WRITE_STATE_WRITE_COMPLETE:
243 rv = DoWriteComplete(rv);
244 break;
245 case WRITE_STATE_DO_CALLBACK:
246 rv = DoWriteCallback();
247 break;
248 case WRITE_STATE_HANDLE_ERROR:
249 rv = DoWriteHandleError(rv);
250 DCHECK_EQ(WRITE_STATE_ERROR, write_state_);
251 break;
252 default:
253 NOTREACHED() << "Unknown state in write state machine: " << state;
254 SetWriteState(WRITE_STATE_ERROR);
255 SetErrorState(CHANNEL_ERROR_UNKNOWN);
256 rv = net::ERR_FAILED;
257 break;
259 } while (rv != net::ERR_IO_PENDING && !IsTerminalWriteState(write_state_));
261 if (IsTerminalWriteState(write_state_)) {
262 logger_->LogSocketWriteState(channel_id_, WriteStateToProto(write_state_));
264 if (write_state_ == WRITE_STATE_ERROR) {
265 FlushWriteQueue();
266 DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
267 VLOG_WITH_CONNECTION(2) << "Sending OnError().";
268 delegate_->OnError(error_state_);
273 int CastTransportImpl::DoWrite() {
274 DCHECK(!write_queue_.empty());
275 WriteRequest& request = write_queue_.front();
277 VLOG_WITH_CONNECTION(2) << "WriteData byte_count = "
278 << request.io_buffer->size() << " bytes_written "
279 << request.io_buffer->BytesConsumed();
281 SetWriteState(WRITE_STATE_WRITE_COMPLETE);
283 int rv = socket_->Write(
284 request.io_buffer.get(), request.io_buffer->BytesRemaining(),
285 base::Bind(&CastTransportImpl::OnWriteResult, base::Unretained(this)));
286 return rv;
289 int CastTransportImpl::DoWriteComplete(int result) {
290 VLOG_WITH_CONNECTION(2) << "DoWriteComplete result=" << result;
291 DCHECK(!write_queue_.empty());
292 logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_WRITE, result);
293 if (result <= 0) { // NOTE that 0 also indicates an error
294 SetErrorState(CHANNEL_ERROR_SOCKET_ERROR);
295 SetWriteState(WRITE_STATE_HANDLE_ERROR);
296 return result == 0 ? net::ERR_FAILED : result;
299 // Some bytes were successfully written
300 WriteRequest& request = write_queue_.front();
301 scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
302 io_buffer->DidConsume(result);
303 if (io_buffer->BytesRemaining() == 0) { // Message fully sent
304 SetWriteState(WRITE_STATE_DO_CALLBACK);
305 } else {
306 SetWriteState(WRITE_STATE_WRITE);
309 return net::OK;
312 int CastTransportImpl::DoWriteCallback() {
313 VLOG_WITH_CONNECTION(2) << "DoWriteCallback";
314 DCHECK(!write_queue_.empty());
316 WriteRequest& request = write_queue_.front();
317 int bytes_consumed = request.io_buffer->BytesConsumed();
318 logger_->LogSocketEventForMessage(
319 channel_id_, proto::MESSAGE_WRITTEN, request.message_namespace,
320 base::StringPrintf("Bytes: %d", bytes_consumed));
321 base::MessageLoop::current()->PostTask(FROM_HERE,
322 base::Bind(&base::DoNothing));
323 base::MessageLoop::current()->PostTask(FROM_HERE,
324 base::Bind(request.callback, net::OK));
326 write_queue_.pop();
327 if (write_queue_.empty()) {
328 SetWriteState(WRITE_STATE_IDLE);
329 } else {
330 SetWriteState(WRITE_STATE_WRITE);
333 return net::OK;
336 int CastTransportImpl::DoWriteHandleError(int result) {
337 VLOG_WITH_CONNECTION(2) << "DoWriteHandleError result=" << result;
338 DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
339 DCHECK_LT(result, 0);
340 SetWriteState(WRITE_STATE_ERROR);
341 return net::ERR_FAILED;
344 void CastTransportImpl::Start() {
345 DCHECK(CalledOnValidThread());
346 DCHECK(!started_);
347 DCHECK_EQ(READ_STATE_READ, read_state_);
348 DCHECK(delegate_) << "Read delegate must be set prior to calling Start()";
349 started_ = true;
350 delegate_->Start();
351 SetReadState(READ_STATE_READ);
353 // Start the read state machine.
354 OnReadResult(net::OK);
357 void CastTransportImpl::OnReadResult(int result) {
358 DCHECK(CalledOnValidThread());
359 // Network operations can either finish synchronously or asynchronously.
360 // This method executes the state machine transitions in a loop so that
361 // write state transitions happen even when network operations finish
362 // synchronously.
363 int rv = result;
364 do {
365 VLOG_WITH_CONNECTION(2) << "OnReadResult(state=" << read_state_
366 << ", result=" << rv << ")";
367 ReadState state = read_state_;
368 read_state_ = READ_STATE_UNKNOWN;
370 switch (state) {
371 case READ_STATE_READ:
372 rv = DoRead();
373 break;
374 case READ_STATE_READ_COMPLETE:
375 rv = DoReadComplete(rv);
376 break;
377 case READ_STATE_DO_CALLBACK:
378 rv = DoReadCallback();
379 break;
380 case READ_STATE_HANDLE_ERROR:
381 rv = DoReadHandleError(rv);
382 DCHECK_EQ(read_state_, READ_STATE_ERROR);
383 break;
384 default:
385 NOTREACHED() << "Unknown state in read state machine: " << state;
386 SetReadState(READ_STATE_ERROR);
387 SetErrorState(CHANNEL_ERROR_UNKNOWN);
388 rv = net::ERR_FAILED;
389 break;
391 } while (rv != net::ERR_IO_PENDING && !IsTerminalReadState(read_state_));
393 if (IsTerminalReadState(read_state_)) {
394 DCHECK_EQ(READ_STATE_ERROR, read_state_);
395 logger_->LogSocketReadState(channel_id_, ReadStateToProto(read_state_));
396 VLOG_WITH_CONNECTION(2) << "Sending OnError().";
397 delegate_->OnError(error_state_);
401 int CastTransportImpl::DoRead() {
402 VLOG_WITH_CONNECTION(2) << "DoRead";
403 SetReadState(READ_STATE_READ_COMPLETE);
405 // Determine how many bytes need to be read.
406 size_t num_bytes_to_read = framer_->BytesRequested();
407 DCHECK_GT(num_bytes_to_read, 0u);
409 // Read up to num_bytes_to_read into |current_read_buffer_|.
410 return socket_->Read(
411 read_buffer_.get(), base::checked_cast<uint32>(num_bytes_to_read),
412 base::Bind(&CastTransportImpl::OnReadResult, base::Unretained(this)));
415 int CastTransportImpl::DoReadComplete(int result) {
416 VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result;
417 logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_READ, result);
418 if (result <= 0) {
419 VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket.";
420 SetErrorState(CHANNEL_ERROR_SOCKET_ERROR);
421 SetReadState(READ_STATE_HANDLE_ERROR);
422 return result == 0 ? net::ERR_FAILED : result;
425 size_t message_size;
426 DCHECK(!current_message_);
427 ChannelError framing_error;
428 current_message_ = framer_->Ingest(result, &message_size, &framing_error);
429 if (current_message_.get() && (framing_error == CHANNEL_ERROR_NONE)) {
430 DCHECK_GT(message_size, static_cast<size_t>(0));
431 logger_->LogSocketEventForMessage(
432 channel_id_, proto::MESSAGE_READ, current_message_->namespace_(),
433 base::StringPrintf("Message size: %u",
434 static_cast<uint32>(message_size)));
435 SetReadState(READ_STATE_DO_CALLBACK);
436 } else if (framing_error != CHANNEL_ERROR_NONE) {
437 DCHECK(!current_message_);
438 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
439 SetReadState(READ_STATE_HANDLE_ERROR);
440 } else {
441 DCHECK(!current_message_);
442 SetReadState(READ_STATE_READ);
444 return net::OK;
447 int CastTransportImpl::DoReadCallback() {
448 VLOG_WITH_CONNECTION(2) << "DoReadCallback";
449 if (!IsCastMessageValid(*current_message_)) {
450 SetReadState(READ_STATE_HANDLE_ERROR);
451 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
452 return net::ERR_INVALID_RESPONSE;
454 SetReadState(READ_STATE_READ);
455 delegate_->OnMessage(*current_message_);
456 current_message_.reset();
457 return net::OK;
460 int CastTransportImpl::DoReadHandleError(int result) {
461 VLOG_WITH_CONNECTION(2) << "DoReadHandleError";
462 DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
463 DCHECK_LE(result, 0);
464 SetReadState(READ_STATE_ERROR);
465 return net::ERR_FAILED;
468 } // namespace cast_channel
469 } // namespace api
470 } // namespace extensions