1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
// File-scope using-directive: brings the protobuf I/O stream types used below
// (CodedInputStream, CodedOutputStream) into scope.
// NOTE(review): a using-directive pulls in the whole namespace; consider
// individual using-declarations for the two stream types instead.
15 using namespace google::protobuf::io
;
// Wire framing produced by Login()/SendMessage() and consumed by
// WaitForData(): a one-byte MCS version (login handshake only), a one-byte
// message tag, a varint32-encoded payload size, then the serialized protobuf.
21 // # of bytes a MCS version packet consumes.
22 const int kVersionPacketLen
= 1;
23 // # of bytes a tag packet consumes.
24 const int kTagPacketLen
= 1;
25 // Max # of bytes a length packet consumes. A Varint32 can consume up to 5 bytes
26 // (the MSB in each byte is reserved for denoting whether more bytes follow).
27 // But, the protocol only allows for 4KiB payloads, and the socket stream buffer
28 // is only of size 8KiB. As such we should never need more than 2 bytes (max
29 // value of 16KiB). Anything higher than that will result in an error, either
30 // because the socket stream buffer overflowed or too many bytes were required
31 // in the size packet.
32 const int kSizePacketLenMin
= 1;
33 const int kSizePacketLenMax
= 2;
35 // The current MCS protocol version.
// Written as the first byte of the login handshake (Login()) and compared
// against the server's version byte (OnGotVersion()).
36 const int kMCSVersion
= 41;
// Stores the read timeout and the three callbacks and starts with the
// handshake marked incomplete. The callbacks are DCHECKed to be non-null in
// Init(), not here. |weak_ptr_factory_| vends the weak pointers that every
// asynchronously bound callback in this file is tied to.
// NOTE(review): initializers for socket_, message_tag_ and message_size_ are
// not visible in this view; confirm they are initialized somewhere.
40 ConnectionHandlerImpl::ConnectionHandlerImpl(
41 base::TimeDelta read_timeout
,
42 const ProtoReceivedCallback
& read_callback
,
43 const ProtoSentCallback
& write_callback
,
44 const ConnectionChangedCallback
& connection_callback
)
45 : read_timeout_(read_timeout
),
47 handshake_complete_(false),
50 read_callback_(read_callback
),
51 write_callback_(write_callback
),
52 connection_callback_(connection_callback
),
53 weak_ptr_factory_(this) {
// Destructor. No explicit cleanup is visible in this view; owned members
// (streams, timer, weak pointer factory) are released by their own destructors.
56 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
// Prepares the handler for a new connection. It invalidates weak pointers
// held by callbacks from any previous connection, clears the handshake flag,
// and creates fresh input/output streams over the socket.
// NOTE(review): the streams are built from the member |socket_|, not from
// the |socket| parameter. Neither an assignment to |socket_| nor any use of
// |login_request| is visible in this view. Confirm both happen (presumably
// `socket_ = socket;` before the streams are created and a Login() call after).
59 void ConnectionHandlerImpl::Init(
60 const mcs_proto::LoginRequest
& login_request
,
61 net::StreamSocket
* socket
) {
// All three callbacks must have been supplied to the constructor.
62 DCHECK(!read_callback_
.is_null());
63 DCHECK(!write_callback_
.is_null());
64 DCHECK(!connection_callback_
.is_null());
66 // Invalidate any previously outstanding reads.
67 weak_ptr_factory_
.InvalidateWeakPtrs();
69 handshake_complete_
= false;
73 input_stream_
.reset(new SocketInputStream(socket_
));
74 output_stream_
.reset(new SocketOutputStream(socket_
));
// Resets the handler.
// NOTE(review): the body is not visible in this view; presumably it closes
// the current connection via CloseConnection(). Confirm.
79 void ConnectionHandlerImpl::Reset() {
// Returns true only when all of the following hold:
// - the login handshake has completed;
// - an output stream exists (CloseConnection() resets it);
// - the output stream is in the EMPTY state, which SendMessage() and Login()
//   DCHECK before writing.
83 bool ConnectionHandlerImpl::CanSendMessage() const {
84 return handshake_complete_
&& output_stream_
.get() &&
85 output_stream_
->GetState() == SocketOutputStream::EMPTY
;
// Serializes |message| onto the output stream as one MCS frame:
// [tag byte][varint32 payload size][serialized payload].
// It then flushes the stream, with OnMessageSent() bound (via a weak pointer)
// as the completion callback. Requires a completed handshake and an idle
// (EMPTY) output stream, as reported by CanSendMessage().
88 void ConnectionHandlerImpl::SendMessage(
89 const google::protobuf::MessageLite
& message
) {
90 DCHECK_EQ(output_stream_
->GetState(), SocketOutputStream::EMPTY
);
91 DCHECK(handshake_complete_
);
94 CodedOutputStream
coded_output_stream(output_stream_
.get());
95 DVLOG(1) << "Writing proto of size " << message
.ByteSize();
96 int tag
= GetMCSProtoTag(message
);
// The wire format carries the tag as a single byte. Writing the first byte
// of the int's storage (&tag) is endianness-dependent: on big-endian hosts
// it emits the high-order byte (0) instead of the tag. Take the low-order
// byte explicitly instead. The DCHECK catches tags that do not fit in one
// byte, e.g. a negative "unknown message" value, if GetMCSProtoTag returns one.
DCHECK(tag >= 0 && tag <= 0xFF);
const unsigned char tag_byte = static_cast<unsigned char>(tag);
98 coded_output_stream
.WriteRaw(&tag_byte
, 1);
99 coded_output_stream
.WriteVarint32(message
.ByteSize());
100 message
.SerializeToCodedStream(&coded_output_stream
);
103 if (output_stream_
->Flush(
104 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
105 weak_ptr_factory_
.GetWeakPtr())) != net::ERR_IO_PENDING
) {
// Writes the login handshake frame:
// [MCS version byte][LoginRequest tag byte][varint32 size][serialized request].
// It then flushes the output stream. If the flush does not return
// ERR_IO_PENDING, OnMessageSent() is posted as a task; otherwise the Flush
// callback runs it on completion. Finally it starts the read timeout timer
// and waits for the server's version, tag and size bytes.
110 void ConnectionHandlerImpl::Login(
111 const google::protobuf::MessageLite
& login_request
) {
112 DCHECK_EQ(output_stream_
->GetState(), SocketOutputStream::EMPTY
);
114 const char version_byte
[1] = {kMCSVersion
};
115 const char login_request_tag
[1] = {kLoginRequestTag
};
117 CodedOutputStream
coded_output_stream(output_stream_
.get());
118 coded_output_stream
.WriteRaw(version_byte
, 1);
119 coded_output_stream
.WriteRaw(login_request_tag
, 1);
120 coded_output_stream
.WriteVarint32(login_request
.ByteSize());
121 login_request
.SerializeToCodedStream(&coded_output_stream
);
124 if (output_stream_
->Flush(
125 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
126 weak_ptr_factory_
.GetWeakPtr())) != net::ERR_IO_PENDING
) {
// The flush finished synchronously. NOTE(review): presumably OnMessageSent()
// is posted rather than called directly to avoid re-entering it from Login().
127 base::MessageLoop::current()->PostTask(
129 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
130 weak_ptr_factory_
.GetWeakPtr()));
// OnTimeout() fires if the handshake response does not arrive in time.
133 read_timeout_timer_
.Start(FROM_HERE
,
135 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
136 weak_ptr_factory_
.GetWeakPtr()));
137 WaitForData(MCS_VERSION_TAG_AND_SIZE
);
// Completion handler for a flushed write.
// - No output stream (the connection was already closed): only DCHECKs that
//   teardown is complete.
// - Output stream did not return to EMPTY: the write failed, and the
//   stream's last error (or ERR_FAILED if none was recorded) is reported via
//   connection_callback_.
// - Otherwise write_callback_ is run.
// NOTE(review): the early returns separating these branches are not visible
// in this view.
140 void ConnectionHandlerImpl::OnMessageSent() {
141 if (!output_stream_
.get()) {
142 // The connection has already been closed. Just return.
143 DCHECK(!input_stream_
.get());
144 DCHECK(!read_timeout_timer_
.IsRunning());
148 if (output_stream_
->GetState() != SocketOutputStream::EMPTY
) {
149 int last_error
= output_stream_
->last_error();
151 // If the socket stream had an error, plumb it up, else plumb up FAILED.
152 if (last_error
== net::OK
)
153 last_error
= net::ERR_FAILED
;
154 connection_callback_
.Run(last_error
);
158 write_callback_
.Run();
// Begins reading the next frame (tag byte + size). The input stream must be
// EMPTY or READY. Any other state is treated as an error by WaitForData().
161 void ConnectionHandlerImpl::GetNextMessage() {
162 DCHECK(SocketInputStream::EMPTY
== input_stream_
->GetState() ||
163 SocketInputStream::READY
== input_stream_
->GetState());
167 WaitForData(MCS_TAG_AND_SIZE
);
// Ensures enough bytes are buffered for |state|, reading from the socket if
// needed, and then dispatches to the matching processing step.
// - Returns early (DCHECKing teardown) if the connection was already closed.
// - Reports the input stream's error via connection_callback_ if the stream
//   is already in an error state, or if a refresh leaves it not READY.
// - Re-posts itself when a refresh returns fewer than the minimum bytes.
// Asynchronous re-entries are bound through weak pointers, so they are
// dropped once the weak pointers are invalidated.
170 void ConnectionHandlerImpl::WaitForData(ProcessingState state
) {
171 DVLOG(1) << "Waiting for MCS data: state == " << state
;
173 if (!input_stream_
) {
174 // The connection has already been closed. Just return.
175 DCHECK(!output_stream_
.get());
176 DCHECK(!read_timeout_timer_
.IsRunning());
180 if (input_stream_
->GetState() != SocketInputStream::EMPTY
&&
181 input_stream_
->GetState() != SocketInputStream::READY
) {
182 // An error occurred.
// The failing stream here is the input stream, so report its error. Reading
// output_stream_'s error would mask the real read failure as ERR_FAILED.
// This matches the refresh-error branch further down.
183 int last_error
= input_stream_
->last_error();
185 // If the socket stream had an error, plumb it up, else plumb up FAILED.
186 if (last_error
== net::OK
)
187 last_error
= net::ERR_FAILED
;
188 connection_callback_
.Run(last_error
);
192 // Used to determine whether a Socket::Read is necessary.
193 int min_bytes_needed
= 0;
194 // Used to limit the size of the Socket::Read.
195 int max_bytes_needed
= 0;
198 case MCS_VERSION_TAG_AND_SIZE
:
199 min_bytes_needed
= kVersionPacketLen
+ kTagPacketLen
+ kSizePacketLenMin
;
200 max_bytes_needed
= kVersionPacketLen
+ kTagPacketLen
+ kSizePacketLenMax
;
202 case MCS_TAG_AND_SIZE
:
203 min_bytes_needed
= kTagPacketLen
+ kSizePacketLenMin
;
204 max_bytes_needed
= kTagPacketLen
+ kSizePacketLenMax
;
207 // If in this state, the minimum size packet length must already have been
208 // insufficient, so set both to the max length.
209 min_bytes_needed
= kSizePacketLenMax
;
210 max_bytes_needed
= kSizePacketLenMax
;
212 case MCS_PROTO_BYTES
:
// Payload reads restart the read timeout.
213 read_timeout_timer_
.Reset();
214 // No variability in the message size, set both to the same.
215 min_bytes_needed
= message_size_
;
216 max_bytes_needed
= message_size_
;
221 DCHECK_GE(max_bytes_needed
, min_bytes_needed
);
// Only hit the socket when the buffer lacks the minimum, and cap the read so
// it does not consume bytes belonging to the next frame.
223 int unread_byte_count
= input_stream_
->UnreadByteCount();
224 if (min_bytes_needed
> unread_byte_count
&&
225 input_stream_
->Refresh(
226 base::Bind(&ConnectionHandlerImpl::WaitForData
,
227 weak_ptr_factory_
.GetWeakPtr(),
229 max_bytes_needed
- unread_byte_count
) == net::ERR_IO_PENDING
) {
233 // Check for refresh errors.
234 if (input_stream_
->GetState() != SocketInputStream::READY
) {
235 // An error occurred.
236 int last_error
= input_stream_
->last_error();
238 // If the socket stream had an error, plumb it up, else plumb up FAILED.
239 if (last_error
== net::OK
)
240 last_error
= net::ERR_FAILED
;
241 connection_callback_
.Run(last_error
);
245 // Check whether read is complete, or needs to be continued (
246 // SocketInputStream::Refresh can finish without reading all the data).
247 if (input_stream_
->UnreadByteCount() < min_bytes_needed
) {
248 DVLOG(1) << "Socket read finished prematurely. Waiting for "
249 << min_bytes_needed
- input_stream_
->UnreadByteCount()
251 base::MessageLoop::current()->PostTask(
253 base::Bind(&ConnectionHandlerImpl::WaitForData
,
254 weak_ptr_factory_
.GetWeakPtr(),
259 // Received enough bytes, process them.
260 DVLOG(1) << "Processing MCS data: state == " << state
;
262 case MCS_VERSION_TAG_AND_SIZE
:
265 case MCS_TAG_AND_SIZE
:
271 case MCS_PROTO_BYTES
:
// Handles the first byte of the server's handshake response. It reads one
// version byte and fails the connection with ERR_FAILED if the version is
// below kMCSVersion. Version 38 is temporarily tolerated (see TODO). On
// success it rebuilds the input buffer; the next step processes the
// LoginResponse tag.
// NOTE(review): the declaration of |version| and the call that processes the
// tag are not visible in this view.
279 void ConnectionHandlerImpl::OnGotVersion() {
282 CodedInputStream
coded_input_stream(input_stream_
.get());
283 coded_input_stream
.ReadRaw(&version
, 1);
285 // TODO(zea): remove this when the server is ready.
286 if (version
< kMCSVersion
&& version
!= 38) {
287 LOG(ERROR
) << "Invalid GCM version response: " << static_cast<int>(version
);
288 connection_callback_
.Run(net::ERR_FAILED
);
292 input_stream_
->RebuildBuffer();
294 // Process the LoginResponse message tag.
// Reads the one-byte message tag into |message_tag_|, then starts the read
// timeout timer if it is not already running. If the input stream is not
// READY, it logs and runs read_callback_ with a null message.
// NOTE(review): unlike the other read-error paths, this reports failure via
// read_callback_ (null message) rather than connection_callback_. Confirm
// that callers handle a null message. The step that follows the tag (size
// processing) is not visible in this view.
298 void ConnectionHandlerImpl::OnGotMessageTag() {
299 if (input_stream_
->GetState() != SocketInputStream::READY
) {
300 LOG(ERROR
) << "Failed to receive protobuf tag.";
301 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
306 CodedInputStream
coded_input_stream(input_stream_
.get());
307 coded_input_stream
.ReadRaw(&message_tag_
, 1);
310 DVLOG(1) << "Received proto of type "
311 << static_cast<unsigned int>(message_tag_
);
// Login() may already have started the timer for the handshake response.
313 if (!read_timeout_timer_
.IsRunning()) {
314 read_timeout_timer_
.Start(FROM_HERE
,
316 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
317 weak_ptr_factory_
.GetWeakPtr()));
// Decodes the varint32 payload size into |message_size_|.
// - If decoding fails and fewer than kSizePacketLenMax bytes were buffered,
//   it backs up the byte consumed and waits for the full size
//   (MCS_FULL_SIZE).
// - If decoding fails even though kSizePacketLenMax bytes were buffered, the
//   size exceeds what the protocol allows, and the connection is failed with
//   ERR_FILE_TOO_BIG.
// - A decoded positive size leads to waiting for that many payload bytes.
// - If the input stream is not READY, it logs and runs read_callback_ with a
//   null message.
// NOTE(review): the handling of a zero size is not visible in this view.
322 void ConnectionHandlerImpl::OnGotMessageSize() {
323 if (input_stream_
->GetState() != SocketInputStream::READY
) {
324 LOG(ERROR
) << "Failed to receive message size.";
325 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
// Remember how much was buffered so the bytes consumed by a failed decode
// can be measured and backed up.
329 bool need_another_byte
= false;
330 int prev_byte_count
= input_stream_
->UnreadByteCount();
332 CodedInputStream
coded_input_stream(input_stream_
.get());
333 if (!coded_input_stream
.ReadVarint32(&message_size_
))
334 need_another_byte
= true;
337 if (need_another_byte
) {
338 DVLOG(1) << "Expecting another message size byte.";
339 if (prev_byte_count
>= kSizePacketLenMax
) {
340 // Already had enough bytes, something else went wrong.
341 LOG(ERROR
) << "Failed to process message size, too many bytes needed.";
342 connection_callback_
.Run(net::ERR_FILE_TOO_BIG
);
345 // Back up by the amount read (should always be 1 byte).
346 int bytes_read
= prev_byte_count
- input_stream_
->UnreadByteCount();
347 DCHECK_EQ(bytes_read
, 1);
348 input_stream_
->BackUp(bytes_read
);
349 WaitForData(MCS_FULL_SIZE
);
353 DVLOG(1) << "Proto size: " << message_size_
;
355 if (message_size_
> 0)
356 WaitForData(MCS_PROTO_BYTES
);
// Builds the protobuf for |message_tag_| from the buffered payload and
// delivers it. The read timeout is stopped first.
// - A zero-length message with a known tag is delivered as the
//   default-constructed proto.
// - The connection fails if the stream is not READY (ERR_FAILED), the tag is
//   unknown (ERR_INVALID_ARGUMENT), or the payload does not parse
//   (ERR_FAILED).
// - On success it rebuilds the input buffer and posts GetNextMessage(). A
//   LoginResponse marks the handshake complete and runs connection_callback_
//   with net::OK. The message is then passed to read_callback_.
361 void ConnectionHandlerImpl::OnGotMessageBytes() {
362 read_timeout_timer_
.Stop();
363 scoped_ptr
<google::protobuf::MessageLite
> protobuf(
364 BuildProtobufFromTag(message_tag_
));
365 // Messages with no content are valid; just use the default protobuf for
// that tag.
367 if (protobuf
.get() && message_size_
== 0) {
368 base::MessageLoop::current()->PostTask(
370 base::Bind(&ConnectionHandlerImpl::GetNextMessage
,
371 weak_ptr_factory_
.GetWeakPtr()));
372 read_callback_
.Run(protobuf
.Pass());
376 if (input_stream_
->GetState() != SocketInputStream::READY
) {
377 LOG(ERROR
) << "Failed to extract protobuf bytes of type "
378 << static_cast<unsigned int>(message_tag_
);
379 // Reset the connection.
380 connection_callback_
.Run(net::ERR_FAILED
);
// A null protobuf means BuildProtobufFromTag() did not recognize the tag.
384 if (!protobuf
.get()) {
385 LOG(ERROR
) << "Received message of invalid type "
386 << static_cast<unsigned int>(message_tag_
);
387 connection_callback_
.Run(net::ERR_INVALID_ARGUMENT
);
// ParsePartial* does not verify that required fields are set.
392 CodedInputStream
coded_input_stream(input_stream_
.get());
393 if (!protobuf
->ParsePartialFromCodedStream(&coded_input_stream
)) {
394 LOG(ERROR
) << "Unable to parse GCM message of type "
395 << static_cast<unsigned int>(message_tag_
);
396 // Reset the connection.
397 connection_callback_
.Run(net::ERR_FAILED
);
402 input_stream_
->RebuildBuffer();
403 base::MessageLoop::current()->PostTask(
405 base::Bind(&ConnectionHandlerImpl::GetNextMessage
,
406 weak_ptr_factory_
.GetWeakPtr()));
// A LoginResponse completes the handshake. A second one is only logged.
// NOTE(review): the branch structure after the log is not fully visible.
407 if (message_tag_
== kLoginResponseTag
) {
408 if (handshake_complete_
) {
409 LOG(ERROR
) << "Unexpected login response.";
411 handshake_complete_
= true;
412 DVLOG(1) << "GCM Handshake complete.";
413 connection_callback_
.Run(net::OK
);
416 read_callback_
.Run(protobuf
.Pass());
// Bound to |read_timeout_timer_|. The timer is started by Login() and
// OnGotMessageTag(), reset for payload reads, and stopped by
// OnGotMessageBytes(). When it fires, ERR_TIMED_OUT is reported via
// connection_callback_.
419 void ConnectionHandlerImpl::OnTimeout() {
420 LOG(ERROR
) << "Timed out waiting for GCM Protocol buffer.";
422 connection_callback_
.Run(net::ERR_TIMED_OUT
);
// Tears down the current connection:
// - stops the read timeout;
// - disconnects the socket;
// - clears the handshake flag;
// - destroys both streams;
// - invalidates weak pointers, so pending asynchronous callbacks bound
//   through |weak_ptr_factory_| are dropped.
// NOTE(review): no null check on |socket_| is visible in this view. Confirm
// that |socket_| is always set when this runs (e.g. Reset() before Init()).
425 void ConnectionHandlerImpl::CloseConnection() {
426 DVLOG(1) << "Closing connection.";
427 read_timeout_timer_
.Stop();
429 socket_
->Disconnect();
431 handshake_complete_
= false;
434 input_stream_
.reset();
435 output_stream_
.reset();
436 weak_ptr_factory_
.InvalidateWeakPtrs();