1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "extensions/browser/api/cast_channel/cast_transport.h"
10 #include "base/format_macros.h"
11 #include "base/message_loop/message_loop.h"
12 #include "base/numerics/safe_conversions.h"
13 #include "base/strings/stringprintf.h"
14 #include "extensions/browser/api/cast_channel/cast_framer.h"
15 #include "extensions/browser/api/cast_channel/cast_message_util.h"
16 #include "extensions/browser/api/cast_channel/logger.h"
17 #include "extensions/browser/api/cast_channel/logger_util.h"
18 #include "extensions/common/api/cast_channel/cast_channel.pb.h"
19 #include "net/base/net_errors.h"
20 #include "net/socket/socket.h"
// Starts a VLOG statement at verbosity |level| that is prefixed with the
// connection's endpoint and authentication type, e.g. "[<endpoint>, auth=N] ".
// Must be used inside a scope where |ip_endpoint_| and |channel_auth_| are
// visible. The final line deliberately has no trailing backslash so the macro
// ends here and does not absorb the code that follows it.
#define VLOG_WITH_CONNECTION(level)                                           \
  VLOG(level) << "[" << ip_endpoint_.ToString() << ", auth=" << channel_auth_ \
              << "] "
26 namespace extensions
{
28 namespace cast_channel
{
// Constructs a transport whose write state machine starts idle, whose read
// state machine starts at READ_STATE_READ, and which has no recorded error.
// The endpoint and auth type are stored for use in log prefixes.
// NOTE(review): |channel_id| is consumed by the initializer list below but its
// parameter declaration is not visible in this listing; confirm the full
// signature (and the start of the initializer list) against the header.
30 CastTransportImpl::CastTransportImpl(net::Socket
* socket
,
32 const net::IPEndPoint
& ip_endpoint
,
33 ChannelAuthType channel_auth
,
34 scoped_refptr
<Logger
> logger
)
37 write_state_(WRITE_STATE_IDLE
),
38 read_state_(READ_STATE_READ
),
39 error_state_(CHANNEL_ERROR_NONE
),
40 channel_id_(channel_id
),
41 ip_endpoint_(ip_endpoint
),
42 channel_auth_(channel_auth
),
46 // Buffer is reused across messages to minimize unnecessary buffer
// reallocations. Its capacity is set once, to the largest message size the
// framer's header format reports.
48 read_buffer_
= new net::GrowableIOBuffer();
49 read_buffer_
->SetCapacity(MessageFramer::MessageHeader::max_message_size());
// The framer is handed the same buffer object that is later passed to
// socket reads.
50 framer_
.reset(new MessageFramer(read_buffer_
));
// Destructor. DCHECKs that destruction happens on the thread this object is
// bound to.
// NOTE(review): any further teardown statements are not visible in this
// listing.
53 CastTransportImpl::~CastTransportImpl() {
54 DCHECK(CalledOnValidThread());
58 bool CastTransportImpl::IsTerminalWriteState(
59 CastTransportImpl::WriteState write_state
) {
60 return write_state
== WRITE_STATE_ERROR
|| write_state
== WRITE_STATE_IDLE
;
63 bool CastTransportImpl::IsTerminalReadState(
64 CastTransportImpl::ReadState read_state
) {
65 return read_state
== READ_STATE_ERROR
;
// Translates the transport's internal ReadState value into the identically
// named proto::ReadState value. The final return, after all case labels,
// yields proto::READ_STATE_UNKNOWN.
// NOTE(review): the switch header and whatever precedes the final return are
// not visible in this listing.
69 proto::ReadState
CastTransportImpl::ReadStateToProto(
70 CastTransportImpl::ReadState state
) {
72 case CastTransportImpl::READ_STATE_UNKNOWN
:
73 return proto::READ_STATE_UNKNOWN
;
74 case CastTransportImpl::READ_STATE_READ
:
75 return proto::READ_STATE_READ
;
76 case CastTransportImpl::READ_STATE_READ_COMPLETE
:
77 return proto::READ_STATE_READ_COMPLETE
;
78 case CastTransportImpl::READ_STATE_DO_CALLBACK
:
79 return proto::READ_STATE_DO_CALLBACK
;
80 case CastTransportImpl::READ_STATE_HANDLE_ERROR
:
81 return proto::READ_STATE_HANDLE_ERROR
;
82 case CastTransportImpl::READ_STATE_ERROR
:
83 return proto::READ_STATE_ERROR
;
// Fallback for values not matched by any case above.
86 return proto::READ_STATE_UNKNOWN
;
// Translates the transport's internal WriteState value into the identically
// named proto::WriteState value. The final return, after all case labels,
// yields proto::WRITE_STATE_UNKNOWN.
// NOTE(review): the switch header and whatever precedes the final return are
// not visible in this listing.
91 proto::WriteState
CastTransportImpl::WriteStateToProto(
92 CastTransportImpl::WriteState state
) {
94 case CastTransportImpl::WRITE_STATE_IDLE
:
95 return proto::WRITE_STATE_IDLE
;
96 case CastTransportImpl::WRITE_STATE_UNKNOWN
:
97 return proto::WRITE_STATE_UNKNOWN
;
98 case CastTransportImpl::WRITE_STATE_WRITE
:
99 return proto::WRITE_STATE_WRITE
;
100 case CastTransportImpl::WRITE_STATE_WRITE_COMPLETE
:
101 return proto::WRITE_STATE_WRITE_COMPLETE
;
102 case CastTransportImpl::WRITE_STATE_DO_CALLBACK
:
103 return proto::WRITE_STATE_DO_CALLBACK
;
104 case CastTransportImpl::WRITE_STATE_HANDLE_ERROR
:
105 return proto::WRITE_STATE_HANDLE_ERROR
;
106 case CastTransportImpl::WRITE_STATE_ERROR
:
107 return proto::WRITE_STATE_ERROR
;
// Fallback for values not matched by any case above.
110 return proto::WRITE_STATE_UNKNOWN
;
// Translates a ChannelError value into the identically named
// proto::ErrorState value.
// NOTE(review): the switch header and whatever precedes the final return are
// not visible in this listing. Also note that the final fallback returns
// proto::CHANNEL_ERROR_NONE rather than proto::CHANNEL_ERROR_UNKNOWN; confirm
// that reporting "no error" for an unmatched value is intended.
115 proto::ErrorState
CastTransportImpl::ErrorStateToProto(ChannelError state
) {
117 case CHANNEL_ERROR_NONE
:
118 return proto::CHANNEL_ERROR_NONE
;
119 case CHANNEL_ERROR_CHANNEL_NOT_OPEN
:
120 return proto::CHANNEL_ERROR_CHANNEL_NOT_OPEN
;
121 case CHANNEL_ERROR_AUTHENTICATION_ERROR
:
122 return proto::CHANNEL_ERROR_AUTHENTICATION_ERROR
;
123 case CHANNEL_ERROR_CONNECT_ERROR
:
124 return proto::CHANNEL_ERROR_CONNECT_ERROR
;
125 case CHANNEL_ERROR_SOCKET_ERROR
:
126 return proto::CHANNEL_ERROR_SOCKET_ERROR
;
127 case CHANNEL_ERROR_TRANSPORT_ERROR
:
128 return proto::CHANNEL_ERROR_TRANSPORT_ERROR
;
129 case CHANNEL_ERROR_INVALID_MESSAGE
:
130 return proto::CHANNEL_ERROR_INVALID_MESSAGE
;
131 case CHANNEL_ERROR_INVALID_CHANNEL_ID
:
132 return proto::CHANNEL_ERROR_INVALID_CHANNEL_ID
;
133 case CHANNEL_ERROR_CONNECT_TIMEOUT
:
134 return proto::CHANNEL_ERROR_CONNECT_TIMEOUT
;
135 case CHANNEL_ERROR_UNKNOWN
:
136 return proto::CHANNEL_ERROR_UNKNOWN
;
// Fallback for values not matched by any case above.
139 return proto::CHANNEL_ERROR_NONE
;
// Installs |delegate| as the receiver of OnMessage()/OnError() notifications.
// Ownership of the delegate moves into |delegate_| via Pass(). Must be called
// on the thread this object is bound to.
// NOTE(review): some lines of this function are not visible in this listing.
143 void CastTransportImpl::SetReadDelegate(scoped_ptr
<Delegate
> delegate
) {
144 DCHECK(CalledOnValidThread());
146 delegate_
= delegate
.Pass();
// Empties |write_queue_|. For every queued request, a task is posted to the
// current message loop that runs the request's completion callback with
// net::ERR_FAILED; the request is then popped by the loop's step expression.
// NOTE(review): the tail of the loop body is not visible in this listing.
152 void CastTransportImpl::FlushWriteQueue() {
153 for (; !write_queue_
.empty(); write_queue_
.pop()) {
154 net::CompletionCallback
& callback
= write_queue_
.front().callback
;
// The callback is posted rather than run inline.
155 base::MessageLoop::current()->PostTask(
156 FROM_HERE
, base::Bind(callback
, net::ERR_FAILED
));
// Queues |message| for transmission and reports the outcome through
// |callback|.
//
// The message is serialized with MessageFramer::Serialize(). If that fails,
// a SEND_MESSAGE_FAILED event is logged and |callback| is posted to the
// current message loop with net::ERR_FAILED. Otherwise a WriteRequest is
// pushed onto |write_queue_|, a MESSAGE_ENQUEUED event carrying the queue
// size is logged, and - if the write state machine is idle - it is moved to
// WRITE_STATE_WRITE and driven via OnWriteResult(net::OK).
// NOTE(review): the end of the serialization-failure branch is not visible in
// this listing; confirm it returns before the enqueue path.
161 void CastTransportImpl::SendMessage(const CastMessage
& message
,
162 const net::CompletionCallback
& callback
) {
163 DCHECK(CalledOnValidThread());
164 std::string serialized_message
;
165 if (!MessageFramer::Serialize(message
, &serialized_message
)) {
166 logger_
->LogSocketEventForMessage(channel_id_
, proto::SEND_MESSAGE_FAILED
,
167 message
.namespace_(),
168 "Error when serializing message.");
// Failure is reported through a posted task, not by calling |callback|
// directly.
169 base::MessageLoop::current()->PostTask(
170 FROM_HERE
, base::Bind(callback
, net::ERR_FAILED
));
173 WriteRequest
write_request(
174 message
.namespace_(), serialized_message
, callback
);
176 write_queue_
.push(write_request
);
177 logger_
->LogSocketEventForMessage(
178 channel_id_
, proto::MESSAGE_ENQUEUED
, message
.namespace_(),
179 base::StringPrintf("Queue size: %" PRIuS
, write_queue_
.size()));
// Only kick the state machine when no write is already in progress.
180 if (write_state_
== WRITE_STATE_IDLE
) {
181 SetWriteState(WRITE_STATE_WRITE
);
182 OnWriteResult(net::OK
);
// Builds a queued write: records the message namespace and completion
// callback, and wraps a copy of |payload| (via net::StringIOBuffer) in a
// net::DrainableIOBuffer stored in |io_buffer|.
// NOTE(review): the remaining arguments of the DrainableIOBuffer constructor
// are not visible in this listing.
186 CastTransportImpl::WriteRequest::WriteRequest(
187 const std::string
& namespace_
,
188 const std::string
& payload
,
189 const net::CompletionCallback
& callback
)
190 : message_namespace(namespace_
), callback(callback
) {
191 VLOG(2) << "WriteRequest size: " << payload
.size();
192 io_buffer
= new net::DrainableIOBuffer(new net::StringIOBuffer(payload
),
196 CastTransportImpl::WriteRequest::~WriteRequest() {
199 void CastTransportImpl::SetReadState(ReadState read_state
) {
200 if (read_state_
!= read_state
) {
201 read_state_
= read_state
;
202 logger_
->LogSocketReadState(channel_id_
, ReadStateToProto(read_state_
));
206 void CastTransportImpl::SetWriteState(WriteState write_state
) {
207 if (write_state_
!= write_state
) {
208 write_state_
= write_state
;
209 logger_
->LogSocketWriteState(channel_id_
, WriteStateToProto(write_state_
));
213 void CastTransportImpl::SetErrorState(ChannelError error_state
) {
214 VLOG_WITH_CONNECTION(2) << "SetErrorState: " << error_state
;
215 error_state_
= error_state
;
// Drives the write state machine. Called directly to start a write and as the
// completion callback passed to the socket's Write().
//
// |result| is the outcome of the previous step. The machine keeps stepping
// until a step returns net::ERR_IO_PENDING or a terminal write state (error
// or idle) is reached. On reaching WRITE_STATE_ERROR the delegate is notified
// through OnError() with the recorded error state.
// NOTE(review): the loop/switch scaffolding, the declaration of |rv| and
// several statements are not visible in this listing.
218 void CastTransportImpl::OnWriteResult(int result
) {
219 DCHECK(CalledOnValidThread());
220 DCHECK_NE(WRITE_STATE_IDLE
, write_state_
);
// Nothing left to send: park the machine in the idle state.
221 if (write_queue_
.empty()) {
222 SetWriteState(WRITE_STATE_IDLE
);
226 // Network operations can either finish synchronously or asynchronously.
227 // This method executes the state machine transitions in a loop so that
228 // write state transitions happen even when network operations finish
// synchronously.
232 VLOG_WITH_CONNECTION(2) << "OnWriteResult (state=" << write_state_
<< ", "
233 << "result=" << rv
<< ", "
234 << "queue size=" << write_queue_
.size() << ")";
// Snapshot the state to dispatch on, then mark the member as unknown so
// that each handler must set the next state explicitly.
236 WriteState state
= write_state_
;
237 write_state_
= WRITE_STATE_UNKNOWN
;
239 case WRITE_STATE_WRITE
:
242 case WRITE_STATE_WRITE_COMPLETE
:
243 rv
= DoWriteComplete(rv
);
245 case WRITE_STATE_DO_CALLBACK
:
246 rv
= DoWriteCallback();
248 case WRITE_STATE_HANDLE_ERROR
:
249 rv
= DoWriteHandleError(rv
);
250 DCHECK_EQ(WRITE_STATE_ERROR
, write_state_
);
// Unexpected state: treat it as an unknown channel error.
253 NOTREACHED() << "Unknown state in write state machine: " << state
;
254 SetWriteState(WRITE_STATE_ERROR
);
255 SetErrorState(CHANNEL_ERROR_UNKNOWN
);
256 rv
= net::ERR_FAILED
;
259 } while (rv
!= net::ERR_IO_PENDING
&& !IsTerminalWriteState(write_state_
));
// Log the terminal state explicitly.
261 if (IsTerminalWriteState(write_state_
)) {
262 logger_
->LogSocketWriteState(channel_id_
, WriteStateToProto(write_state_
));
264 if (write_state_
== WRITE_STATE_ERROR
) {
266 DCHECK_NE(CHANNEL_ERROR_NONE
, error_state_
);
267 VLOG_WITH_CONNECTION(2) << "Sending OnError().";
268 delegate_
->OnError(error_state_
);
// WRITE_STATE_WRITE handler: issues a socket write for the unsent remainder
// of the request at the front of |write_queue_| and advances the machine to
// WRITE_STATE_WRITE_COMPLETE. OnWriteResult() is passed to the socket as the
// completion callback.
// NOTE(review): the return statement is not visible in this listing;
// presumably the socket's result |rv| is returned - confirm.
273 int CastTransportImpl::DoWrite() {
274 DCHECK(!write_queue_
.empty());
275 WriteRequest
& request
= write_queue_
.front();
277 VLOG_WITH_CONNECTION(2) << "WriteData byte_count = "
278 << request
.io_buffer
->size() << " bytes_written "
279 << request
.io_buffer
->BytesConsumed();
// The next state is set before the write is issued.
281 SetWriteState(WRITE_STATE_WRITE_COMPLETE
);
// base::Unretained(this): the callback holds a raw pointer to this object.
283 int rv
= socket_
->Write(
284 request
.io_buffer
.get(), request
.io_buffer
->BytesRemaining(),
285 base::Bind(&CastTransportImpl::OnWriteResult
, base::Unretained(this)));
// WRITE_STATE_WRITE_COMPLETE handler: processes the |result| of a socket
// write.
//
// A result <= 0 is treated as a socket error: the error state becomes
// CHANNEL_ERROR_SOCKET_ERROR, the machine moves to WRITE_STATE_HANDLE_ERROR,
// and a negative code is returned (0 is mapped to net::ERR_FAILED).
// Otherwise |result| bytes are marked consumed on the front request's buffer;
// the machine moves to WRITE_STATE_DO_CALLBACK once nothing remains, or back
// to WRITE_STATE_WRITE to send the rest.
// NOTE(review): the else branch scaffolding and the success-path return are
// not visible in this listing.
289 int CastTransportImpl::DoWriteComplete(int result
) {
290 VLOG_WITH_CONNECTION(2) << "DoWriteComplete result=" << result
;
291 DCHECK(!write_queue_
.empty());
292 logger_
->LogSocketEventWithRv(channel_id_
, proto::SOCKET_WRITE
, result
);
293 if (result
<= 0) { // NOTE that 0 also indicates an error
294 SetErrorState(CHANNEL_ERROR_SOCKET_ERROR
);
295 SetWriteState(WRITE_STATE_HANDLE_ERROR
);
// Guarantee a negative value for the error handler.
296 return result
== 0 ? net::ERR_FAILED
: result
;
299 // Some bytes were successfully written
300 WriteRequest
& request
= write_queue_
.front();
301 scoped_refptr
<net::DrainableIOBuffer
> io_buffer
= request
.io_buffer
;
302 io_buffer
->DidConsume(result
);
303 if (io_buffer
->BytesRemaining() == 0) { // Message fully sent
304 SetWriteState(WRITE_STATE_DO_CALLBACK
);
// Partial write: go around again for the remaining bytes.
306 SetWriteState(WRITE_STATE_WRITE
);
// WRITE_STATE_DO_CALLBACK handler: the front request has been fully written.
// Logs a MESSAGE_WRITTEN event with the number of bytes consumed, posts the
// request's completion callback with net::OK to the current message loop, and
// then selects the next state: WRITE_STATE_IDLE when the queue is empty,
// WRITE_STATE_WRITE otherwise.
// NOTE(review): the lines between posting the callback and the queue check
// (and the return) are not visible in this listing. The purpose of the posted
// base::DoNothing task is not evident from the code shown.
312 int CastTransportImpl::DoWriteCallback() {
313 VLOG_WITH_CONNECTION(2) << "DoWriteCallback";
314 DCHECK(!write_queue_
.empty());
316 WriteRequest
& request
= write_queue_
.front();
317 int bytes_consumed
= request
.io_buffer
->BytesConsumed();
318 logger_
->LogSocketEventForMessage(
319 channel_id_
, proto::MESSAGE_WRITTEN
, request
.message_namespace
,
320 base::StringPrintf("Bytes: %d", bytes_consumed
));
321 base::MessageLoop::current()->PostTask(FROM_HERE
,
322 base::Bind(&base::DoNothing
));
// The success callback is posted rather than run inline.
323 base::MessageLoop::current()->PostTask(FROM_HERE
,
324 base::Bind(request
.callback
, net::OK
));
327 if (write_queue_
.empty()) {
328 SetWriteState(WRITE_STATE_IDLE
);
330 SetWriteState(WRITE_STATE_WRITE
);
336 int CastTransportImpl::DoWriteHandleError(int result
) {
337 VLOG_WITH_CONNECTION(2) << "DoWriteHandleError result=" << result
;
338 DCHECK_NE(CHANNEL_ERROR_NONE
, error_state_
);
339 DCHECK_LT(result
, 0);
340 SetWriteState(WRITE_STATE_ERROR
);
341 return net::ERR_FAILED
;
// Begins reading from the socket. Requires that the read state machine is
// still at READ_STATE_READ and that a read delegate has been installed;
// both are DCHECKed. Must be called on the thread this object is bound to.
// NOTE(review): some statements of this function are not visible in this
// listing.
344 void CastTransportImpl::Start() {
345 DCHECK(CalledOnValidThread());
347 DCHECK_EQ(READ_STATE_READ
, read_state_
);
348 DCHECK(delegate_
) << "Read delegate must be set prior to calling Start()";
351 SetReadState(READ_STATE_READ
);
353 // Start the read state machine.
354 OnReadResult(net::OK
);
// Drives the read state machine. Called directly by Start() and as the
// completion callback passed to the socket's Read().
//
// |result| is the outcome of the previous step. The machine keeps stepping
// until a step returns net::ERR_IO_PENDING or the terminal READ_STATE_ERROR
// is reached; in the latter case the state is logged and the delegate is
// notified through OnError() with the recorded error state.
// NOTE(review): the loop/switch scaffolding, the declaration of |rv| and
// several statements are not visible in this listing.
357 void CastTransportImpl::OnReadResult(int result
) {
358 DCHECK(CalledOnValidThread());
359 // Network operations can either finish synchronously or asynchronously.
360 // This method executes the state machine transitions in a loop so that
361 // read state transitions happen even when network operations finish
// synchronously.
365 VLOG_WITH_CONNECTION(2) << "OnReadResult(state=" << read_state_
366 << ", result=" << rv
<< ")";
// Snapshot the state to dispatch on, then mark the member as unknown so
// that each handler must set the next state explicitly.
367 ReadState state
= read_state_
;
368 read_state_
= READ_STATE_UNKNOWN
;
371 case READ_STATE_READ
:
374 case READ_STATE_READ_COMPLETE
:
375 rv
= DoReadComplete(rv
);
377 case READ_STATE_DO_CALLBACK
:
378 rv
= DoReadCallback();
380 case READ_STATE_HANDLE_ERROR
:
381 rv
= DoReadHandleError(rv
);
382 DCHECK_EQ(read_state_
, READ_STATE_ERROR
);
// Unexpected state: treat it as an unknown channel error.
385 NOTREACHED() << "Unknown state in read state machine: " << state
;
386 SetReadState(READ_STATE_ERROR
);
387 SetErrorState(CHANNEL_ERROR_UNKNOWN
);
388 rv
= net::ERR_FAILED
;
391 } while (rv
!= net::ERR_IO_PENDING
&& !IsTerminalReadState(read_state_
));
// The only terminal read state is the error state; report it.
393 if (IsTerminalReadState(read_state_
)) {
394 DCHECK_EQ(READ_STATE_ERROR
, read_state_
);
395 logger_
->LogSocketReadState(channel_id_
, ReadStateToProto(read_state_
));
396 VLOG_WITH_CONNECTION(2) << "Sending OnError().";
397 delegate_
->OnError(error_state_
);
401 int CastTransportImpl::DoRead() {
402 VLOG_WITH_CONNECTION(2) << "DoRead";
403 SetReadState(READ_STATE_READ_COMPLETE
);
405 // Determine how many bytes need to be read.
406 size_t num_bytes_to_read
= framer_
->BytesRequested();
407 DCHECK_GT(num_bytes_to_read
, 0u);
409 // Read up to num_bytes_to_read into |current_read_buffer_|.
410 return socket_
->Read(
411 read_buffer_
.get(), base::checked_cast
<uint32
>(num_bytes_to_read
),
412 base::Bind(&CastTransportImpl::OnReadResult
, base::Unretained(this)));
// READ_STATE_READ_COMPLETE handler: processes the |result| of a socket read.
//
// On the socket-error path the error state becomes
// CHANNEL_ERROR_SOCKET_ERROR, the machine moves to READ_STATE_HANDLE_ERROR,
// and a negative code is returned (0 is mapped to net::ERR_FAILED).
// Otherwise the bytes are handed to the framer via Ingest():
//   - a complete message with no framing error -> log MESSAGE_READ and move
//     to READ_STATE_DO_CALLBACK;
//   - a framing error -> record CHANNEL_ERROR_INVALID_MESSAGE and move to
//     READ_STATE_HANDLE_ERROR;
//   - otherwise -> go back to READ_STATE_READ.
// NOTE(review): the condition guarding the socket-error path, the declaration
// of |message_size|, the final else scaffolding and the success-path return
// are not visible in this listing.
415 int CastTransportImpl::DoReadComplete(int result
) {
416 VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result
;
417 logger_
->LogSocketEventWithRv(channel_id_
, proto::SOCKET_READ
, result
);
419 VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket.";
420 SetErrorState(CHANNEL_ERROR_SOCKET_ERROR
);
421 SetReadState(READ_STATE_HANDLE_ERROR
);
// Guarantee a negative value for the error handler.
422 return result
== 0 ? net::ERR_FAILED
: result
;
// No message may be pending when new bytes are ingested.
426 DCHECK(!current_message_
);
427 ChannelError framing_error
;
428 current_message_
= framer_
->Ingest(result
, &message_size
, &framing_error
);
429 if (current_message_
.get() && (framing_error
== CHANNEL_ERROR_NONE
)) {
430 DCHECK_GT(message_size
, static_cast<size_t>(0));
431 logger_
->LogSocketEventForMessage(
432 channel_id_
, proto::MESSAGE_READ
, current_message_
->namespace_(),
433 base::StringPrintf("Message size: %u",
434 static_cast<uint32
>(message_size
)));
435 SetReadState(READ_STATE_DO_CALLBACK
);
436 } else if (framing_error
!= CHANNEL_ERROR_NONE
) {
437 DCHECK(!current_message_
);
438 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE
);
439 SetReadState(READ_STATE_HANDLE_ERROR
);
// Neither a message nor an error: keep reading.
441 DCHECK(!current_message_
);
442 SetReadState(READ_STATE_READ
);
// READ_STATE_DO_CALLBACK handler: delivers the message held in
// |current_message_| to the delegate.
//
// If IsCastMessageValid() rejects the message, the machine moves to
// READ_STATE_HANDLE_ERROR with CHANNEL_ERROR_INVALID_MESSAGE and
// net::ERR_INVALID_RESPONSE is returned. Otherwise the next state is set to
// READ_STATE_READ, the delegate's OnMessage() is invoked, and
// |current_message_| is released.
// NOTE(review): the success-path return is not visible in this listing.
447 int CastTransportImpl::DoReadCallback() {
448 VLOG_WITH_CONNECTION(2) << "DoReadCallback";
449 if (!IsCastMessageValid(*current_message_
)) {
450 SetReadState(READ_STATE_HANDLE_ERROR
);
451 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE
);
452 return net::ERR_INVALID_RESPONSE
;
// The next state is set before the delegate is called.
454 SetReadState(READ_STATE_READ
);
455 delegate_
->OnMessage(*current_message_
);
456 current_message_
.reset();
460 int CastTransportImpl::DoReadHandleError(int result
) {
461 VLOG_WITH_CONNECTION(2) << "DoReadHandleError";
462 DCHECK_NE(CHANNEL_ERROR_NONE
, error_state_
);
463 DCHECK_LE(result
, 0);
464 SetReadState(READ_STATE_ERROR
);
465 return net::ERR_FAILED
;
468 } // namespace cast_channel
470 } // namespace extensions