1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
7 #include "base/location.h"
8 #include "base/thread_task_runner_handle.h"
9 #include "google/protobuf/io/coded_stream.h"
10 #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
11 #include "google_apis/gcm/base/mcs_util.h"
12 #include "google_apis/gcm/base/socket_stream.h"
13 #include "google_apis/gcm/protocol/mcs.pb.h"
14 #include "net/base/net_errors.h"
15 #include "net/socket/stream_socket.h"
17 using namespace google::protobuf::io
;
// MCS wire framing. Every message is laid out as
//   [tag: kTagPacketLen bytes][length: Varint32][payload: length bytes]
// and the server's first reply is additionally preceded by a single version
// byte.

// Size of the protocol version field written once by Login() (and read back
// from the server) ahead of the login messages.
const int kVersionPacketLen = 1;
// Size of the per-message type tag.
const int kTagPacketLen = 1;
// Bounds on the size of the Varint32 payload length. A Varint32 spans one to
// five bytes because the most significant bit of every byte only flags
// whether another byte follows. The protocol currently caps payloads at 4KiB
// and the socket stream buffer is 8KiB, but some applications send larger
// messages; those are accumulated in a temporary in-memory buffer rather than
// parsed in place from the socket stream buffer.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 5;

// Payloads strictly smaller than this (4KiB) are parsed directly from the
// socket stream buffer; payloads of this size or larger go through the
// temporary in-memory buffer.
const int kDefaultDataPacketLimit = 1024 * 4;

// MCS protocol version this client speaks.
const int kMCSVersion = 41;
45 ConnectionHandlerImpl::ConnectionHandlerImpl(
46 base::TimeDelta read_timeout
,
47 const ProtoReceivedCallback
& read_callback
,
48 const ProtoSentCallback
& write_callback
,
49 const ConnectionChangedCallback
& connection_callback
)
50 : read_timeout_(read_timeout
),
52 handshake_complete_(false),
55 read_callback_(read_callback
),
56 write_callback_(write_callback
),
57 connection_callback_(connection_callback
),
58 size_packet_so_far_(0),
59 weak_ptr_factory_(this) {
62 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
65 void ConnectionHandlerImpl::Init(
66 const mcs_proto::LoginRequest
& login_request
,
67 net::StreamSocket
* socket
) {
68 DCHECK(!read_callback_
.is_null());
69 DCHECK(!write_callback_
.is_null());
70 DCHECK(!connection_callback_
.is_null());
72 // Invalidate any previously outstanding reads.
73 weak_ptr_factory_
.InvalidateWeakPtrs();
75 handshake_complete_
= false;
79 input_stream_
.reset(new SocketInputStream(socket_
));
80 output_stream_
.reset(new SocketOutputStream(socket_
));
85 void ConnectionHandlerImpl::Reset() {
89 bool ConnectionHandlerImpl::CanSendMessage() const {
90 return handshake_complete_
&& output_stream_
.get() &&
91 output_stream_
->GetState() == SocketOutputStream::EMPTY
;
94 void ConnectionHandlerImpl::SendMessage(
95 const google::protobuf::MessageLite
& message
) {
96 DCHECK_EQ(output_stream_
->GetState(), SocketOutputStream::EMPTY
);
97 DCHECK(handshake_complete_
);
100 CodedOutputStream
coded_output_stream(output_stream_
.get());
101 DVLOG(1) << "Writing proto of size " << message
.ByteSize();
102 int tag
= GetMCSProtoTag(message
);
104 coded_output_stream
.WriteRaw(&tag
, 1);
105 coded_output_stream
.WriteVarint32(message
.ByteSize());
106 message
.SerializeToCodedStream(&coded_output_stream
);
109 if (output_stream_
->Flush(
110 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
111 weak_ptr_factory_
.GetWeakPtr())) != net::ERR_IO_PENDING
) {
116 void ConnectionHandlerImpl::Login(
117 const google::protobuf::MessageLite
& login_request
) {
118 DCHECK_EQ(output_stream_
->GetState(), SocketOutputStream::EMPTY
);
120 const char version_byte
[1] = {kMCSVersion
};
121 const char login_request_tag
[1] = {kLoginRequestTag
};
123 CodedOutputStream
coded_output_stream(output_stream_
.get());
124 coded_output_stream
.WriteRaw(version_byte
, 1);
125 coded_output_stream
.WriteRaw(login_request_tag
, 1);
126 coded_output_stream
.WriteVarint32(login_request
.ByteSize());
127 login_request
.SerializeToCodedStream(&coded_output_stream
);
130 if (output_stream_
->Flush(
131 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
132 weak_ptr_factory_
.GetWeakPtr())) != net::ERR_IO_PENDING
) {
133 base::ThreadTaskRunnerHandle::Get()->PostTask(
135 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
136 weak_ptr_factory_
.GetWeakPtr()));
139 read_timeout_timer_
.Start(FROM_HERE
,
141 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
142 weak_ptr_factory_
.GetWeakPtr()));
143 WaitForData(MCS_VERSION_TAG_AND_SIZE
);
146 void ConnectionHandlerImpl::OnMessageSent() {
147 if (!output_stream_
.get()) {
148 // The connection has already been closed. Just return.
149 DCHECK(!input_stream_
.get());
150 DCHECK(!read_timeout_timer_
.IsRunning());
154 if (output_stream_
->GetState() != SocketOutputStream::EMPTY
) {
155 int last_error
= output_stream_
->last_error();
157 // If the socket stream had an error, plumb it up, else plumb up FAILED.
158 if (last_error
== net::OK
)
159 last_error
= net::ERR_FAILED
;
160 connection_callback_
.Run(last_error
);
164 write_callback_
.Run();
167 void ConnectionHandlerImpl::GetNextMessage() {
168 DCHECK(SocketInputStream::EMPTY
== input_stream_
->GetState() ||
169 SocketInputStream::READY
== input_stream_
->GetState());
173 WaitForData(MCS_TAG_AND_SIZE
);
176 void ConnectionHandlerImpl::WaitForData(ProcessingState state
) {
177 DVLOG(1) << "Waiting for MCS data: state == " << state
;
179 if (!input_stream_
) {
180 // The connection has already been closed. Just return.
181 DCHECK(!output_stream_
.get());
182 DCHECK(!read_timeout_timer_
.IsRunning());
186 if (input_stream_
->GetState() != SocketInputStream::EMPTY
&&
187 input_stream_
->GetState() != SocketInputStream::READY
) {
188 // An error occurred.
189 int last_error
= output_stream_
->last_error();
191 // If the socket stream had an error, plumb it up, else plumb up FAILED.
192 if (last_error
== net::OK
)
193 last_error
= net::ERR_FAILED
;
194 connection_callback_
.Run(last_error
);
198 // Used to determine whether a Socket::Read is necessary.
199 int min_bytes_needed
= 0;
200 // Used to limit the size of the Socket::Read.
201 int max_bytes_needed
= 0;
204 case MCS_VERSION_TAG_AND_SIZE
:
205 min_bytes_needed
= kVersionPacketLen
+ kTagPacketLen
+ kSizePacketLenMin
;
206 max_bytes_needed
= kVersionPacketLen
+ kTagPacketLen
+ kSizePacketLenMax
;
208 case MCS_TAG_AND_SIZE
:
209 min_bytes_needed
= kTagPacketLen
+ kSizePacketLenMin
;
210 max_bytes_needed
= kTagPacketLen
+ kSizePacketLenMax
;
213 min_bytes_needed
= size_packet_so_far_
+ 1;
214 max_bytes_needed
= kSizePacketLenMax
;
216 case MCS_PROTO_BYTES
:
217 read_timeout_timer_
.Reset();
218 if (message_size_
< kDefaultDataPacketLimit
) {
219 // No variability in the message size, set both to the same.
220 min_bytes_needed
= message_size_
;
221 max_bytes_needed
= message_size_
;
223 int bytes_left
= message_size_
- payload_input_buffer_
.size();
224 if (bytes_left
> kDefaultDataPacketLimit
)
225 bytes_left
= kDefaultDataPacketLimit
;
226 min_bytes_needed
= bytes_left
;
227 max_bytes_needed
= bytes_left
;
233 DCHECK_GE(max_bytes_needed
, min_bytes_needed
);
235 int unread_byte_count
= input_stream_
->UnreadByteCount();
236 if (min_bytes_needed
> unread_byte_count
&&
237 input_stream_
->Refresh(
238 base::Bind(&ConnectionHandlerImpl::WaitForData
,
239 weak_ptr_factory_
.GetWeakPtr(),
241 max_bytes_needed
- unread_byte_count
) == net::ERR_IO_PENDING
) {
245 // Check for refresh errors.
246 if (input_stream_
->GetState() != SocketInputStream::READY
) {
247 // An error occurred.
248 int last_error
= input_stream_
->last_error();
250 // If the socket stream had an error, plumb it up, else plumb up FAILED.
251 if (last_error
== net::OK
)
252 last_error
= net::ERR_FAILED
;
253 connection_callback_
.Run(last_error
);
257 // Check whether read is complete, or needs to be continued (
258 // SocketInputStream::Refresh can finish without reading all the data).
259 if (input_stream_
->UnreadByteCount() < min_bytes_needed
) {
260 DVLOG(1) << "Socket read finished prematurely. Waiting for "
261 << min_bytes_needed
- input_stream_
->UnreadByteCount()
263 base::ThreadTaskRunnerHandle::Get()->PostTask(
265 base::Bind(&ConnectionHandlerImpl::WaitForData
,
266 weak_ptr_factory_
.GetWeakPtr(),
271 // Received enough bytes, process them.
272 DVLOG(1) << "Processing MCS data: state == " << state
;
274 case MCS_VERSION_TAG_AND_SIZE
:
277 case MCS_TAG_AND_SIZE
:
283 case MCS_PROTO_BYTES
:
291 void ConnectionHandlerImpl::OnGotVersion() {
294 CodedInputStream
coded_input_stream(input_stream_
.get());
295 coded_input_stream
.ReadRaw(&version
, 1);
297 // TODO(zea): remove this when the server is ready.
298 if (version
< kMCSVersion
&& version
!= 38) {
299 LOG(ERROR
) << "Invalid GCM version response: " << static_cast<int>(version
);
300 connection_callback_
.Run(net::ERR_FAILED
);
304 input_stream_
->RebuildBuffer();
306 // Process the LoginResponse message tag.
310 void ConnectionHandlerImpl::OnGotMessageTag() {
311 if (input_stream_
->GetState() != SocketInputStream::READY
) {
312 LOG(ERROR
) << "Failed to receive protobuf tag.";
313 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
318 CodedInputStream
coded_input_stream(input_stream_
.get());
319 coded_input_stream
.ReadRaw(&message_tag_
, 1);
322 DVLOG(1) << "Received proto of type "
323 << static_cast<unsigned int>(message_tag_
);
325 if (!read_timeout_timer_
.IsRunning()) {
326 read_timeout_timer_
.Start(FROM_HERE
,
328 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
329 weak_ptr_factory_
.GetWeakPtr()));
334 void ConnectionHandlerImpl::OnGotMessageSize() {
335 if (input_stream_
->GetState() != SocketInputStream::READY
) {
336 LOG(ERROR
) << "Failed to receive message size.";
337 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
341 int prev_byte_count
= input_stream_
->UnreadByteCount();
343 CodedInputStream
coded_input_stream(input_stream_
.get());
344 if (!coded_input_stream
.ReadVarint32(&message_size_
)) {
345 DVLOG(1) << "Expecting another message size byte.";
346 if (prev_byte_count
>= kSizePacketLenMax
) {
347 // Already had enough bytes, something else went wrong.
348 LOG(ERROR
) << "Failed to process message size";
349 connection_callback_
.Run(net::ERR_FILE_TOO_BIG
);
352 // Back up by the amount read.
353 int bytes_read
= prev_byte_count
- input_stream_
->UnreadByteCount();
354 input_stream_
->BackUp(bytes_read
);
355 size_packet_so_far_
= bytes_read
;
356 WaitForData(MCS_SIZE
);
361 DVLOG(1) << "Proto size: " << message_size_
;
362 size_packet_so_far_
= 0;
363 payload_input_buffer_
.clear();
365 if (message_size_
> 0)
366 WaitForData(MCS_PROTO_BYTES
);
371 void ConnectionHandlerImpl::OnGotMessageBytes() {
372 read_timeout_timer_
.Stop();
373 scoped_ptr
<google::protobuf::MessageLite
> protobuf(
374 BuildProtobufFromTag(message_tag_
));
375 // Messages with no content are valid; just use the default protobuf for
377 if (protobuf
.get() && message_size_
== 0) {
378 base::ThreadTaskRunnerHandle::Get()->PostTask(
380 base::Bind(&ConnectionHandlerImpl::GetNextMessage
,
381 weak_ptr_factory_
.GetWeakPtr()));
382 read_callback_
.Run(protobuf
.Pass());
386 if (input_stream_
->GetState() != SocketInputStream::READY
) {
387 LOG(ERROR
) << "Failed to extract protobuf bytes of type "
388 << static_cast<unsigned int>(message_tag_
);
389 // Reset the connection.
390 connection_callback_
.Run(net::ERR_FAILED
);
394 if (!protobuf
.get()) {
395 LOG(ERROR
) << "Received message of invalid type "
396 << static_cast<unsigned int>(message_tag_
);
397 connection_callback_
.Run(net::ERR_INVALID_ARGUMENT
);
401 if (message_size_
< kDefaultDataPacketLimit
) {
402 CodedInputStream
coded_input_stream(input_stream_
.get());
403 if (!protobuf
->ParsePartialFromCodedStream(&coded_input_stream
)) {
404 LOG(ERROR
) << "Unable to parse GCM message of type "
405 << static_cast<unsigned int>(message_tag_
);
406 // Reset the connection.
407 connection_callback_
.Run(net::ERR_FAILED
);
411 // Copy any data in the input stream onto the end of the buffer.
412 const void* data_ptr
= NULL
;
414 input_stream_
->Next(&data_ptr
, &size
);
415 payload_input_buffer_
.insert(payload_input_buffer_
.end(),
416 static_cast<const uint8
*>(data_ptr
),
417 static_cast<const uint8
*>(data_ptr
) + size
);
418 DCHECK_LE(payload_input_buffer_
.size(), message_size_
);
420 if (payload_input_buffer_
.size() == message_size_
) {
421 ArrayInputStream
buffer_input_stream(payload_input_buffer_
.data(),
422 payload_input_buffer_
.size());
423 CodedInputStream
coded_input_stream(&buffer_input_stream
);
424 if (!protobuf
->ParsePartialFromCodedStream(&coded_input_stream
)) {
425 LOG(ERROR
) << "Unable to parse GCM message of type "
426 << static_cast<unsigned int>(message_tag_
);
427 // Reset the connection.
428 connection_callback_
.Run(net::ERR_FAILED
);
432 // Continue reading data.
433 DVLOG(1) << "Continuing data read. Buffer size is "
434 << payload_input_buffer_
.size()
435 << ", expecting " << message_size_
;
436 input_stream_
->RebuildBuffer();
438 read_timeout_timer_
.Start(FROM_HERE
,
440 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
441 weak_ptr_factory_
.GetWeakPtr()));
442 WaitForData(MCS_PROTO_BYTES
);
447 input_stream_
->RebuildBuffer();
448 base::ThreadTaskRunnerHandle::Get()->PostTask(
450 base::Bind(&ConnectionHandlerImpl::GetNextMessage
,
451 weak_ptr_factory_
.GetWeakPtr()));
452 if (message_tag_
== kLoginResponseTag
) {
453 if (handshake_complete_
) {
454 LOG(ERROR
) << "Unexpected login response.";
456 handshake_complete_
= true;
457 DVLOG(1) << "GCM Handshake complete.";
458 connection_callback_
.Run(net::OK
);
461 read_callback_
.Run(protobuf
.Pass());
464 void ConnectionHandlerImpl::OnTimeout() {
465 LOG(ERROR
) << "Timed out waiting for GCM Protocol buffer.";
467 connection_callback_
.Run(net::ERR_TIMED_OUT
);
470 void ConnectionHandlerImpl::CloseConnection() {
471 DVLOG(1) << "Closing connection.";
472 read_timeout_timer_
.Stop();
474 socket_
->Disconnect();
476 handshake_complete_
= false;
479 size_packet_so_far_
= 0;
480 payload_input_buffer_
.clear();
481 input_stream_
.reset();
482 output_stream_
.reset();
483 weak_ptr_factory_
.InvalidateWeakPtrs();