1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
15 using namespace google::protobuf::io
;
// Wire-format sizes, in bytes, of the fixed-layout prefixes of MCS packets.

// The MCS version packet is exactly one byte.
const int kVersionPacketLen = 1;

// Every message tag is exactly one byte.
const int kTagPacketLen = 1;

// The message-size field is a varint: it occupies at least
// kSizePacketLenMin bytes and at most kSizePacketLenMax bytes.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// MCS protocol version spoken by this client; it is sent as the first byte of
// the login packet.
const int kMCSVersion = 41;
// Creates a connection handler.
//   |read_callback| receives each parsed message, or a null message on a
//     framing failure.
//   |write_callback| is run after a message has been written successfully.
//   |connection_callback| receives connection results: net::OK when the login
//     handshake completes, or a net error code on failure.
//   |read_timeout| is stored in read_timeout_ (presumably the delay used for
//     read_timeout_timer_ — confirm).
34 ConnectionHandlerImpl::ConnectionHandlerImpl(
35 base::TimeDelta read_timeout
,
36 const ProtoReceivedCallback
& read_callback
,
37 const ProtoSentCallback
& write_callback
,
38 const ConnectionChangedCallback
& connection_callback
)
39 : read_timeout_(read_timeout
),
41 handshake_complete_(false),
44 read_callback_(read_callback
),
45 write_callback_(write_callback
),
46 connection_callback_(connection_callback
),
47 weak_ptr_factory_(this) {
// Destroys the handler. Callbacks and tasks bound through weak_ptr_factory_
// hold weak pointers, so they are not expected to run after destruction.
50 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
// Prepares the handler for a new connection. It does the following:
//   - DCHECKs that all three callbacks were supplied at construction.
//   - Invalidates callbacks still pending from a previous connection.
//   - Clears the handshake flag.
//   - Wraps socket_ in fresh input and output streams.
// NOTE(review): both streams wrap the member socket_ rather than the |socket|
// argument. Confirm that socket_ is assigned from |socket| before the streams
// are created, and that |login_request| is forwarded (presumably to Login()).
53 void ConnectionHandlerImpl::Init(
54 const mcs_proto::LoginRequest
& login_request
,
55 net::StreamSocket
* socket
) {
56 DCHECK(!read_callback_
.is_null());
57 DCHECK(!write_callback_
.is_null());
58 DCHECK(!connection_callback_
.is_null());
60 // Invalidate any previously outstanding reads.
61 weak_ptr_factory_
.InvalidateWeakPtrs();
63 handshake_complete_
= false;
67 input_stream_
.reset(new SocketInputStream(socket_
));
68 output_stream_
.reset(new SocketOutputStream(socket_
));
// Resets the handler's connection state.
// NOTE(review): presumably tears the connection down via CloseConnection() —
// confirm against the implementation.
73 void ConnectionHandlerImpl::Reset() {
// Returns true if a new message can be sent right now. That requires all of:
//   - the login handshake has completed,
//   - an output stream exists,
//   - that stream is in the EMPTY state (the same precondition SendMessage()
//     DCHECKs).
77 bool ConnectionHandlerImpl::CanSendMessage() const {
78 return handshake_complete_
&& output_stream_
.get() &&
79 output_stream_
->GetState() == SocketOutputStream::EMPTY
;
// Frames |message| as [1-byte MCS tag][varint32 byte size][serialized proto],
// writes it to the output stream and flushes it. OnMessageSent() is the
// asynchronous Flush() completion callback.
// Preconditions (DCHECKed): the output stream is EMPTY and the handshake has
// completed (see CanSendMessage()).
82 void ConnectionHandlerImpl::SendMessage(
83 const google::protobuf::MessageLite
& message
) {
84 DCHECK_EQ(output_stream_
->GetState(), SocketOutputStream::EMPTY
);
85 DCHECK(handshake_complete_
);
88 CodedOutputStream
coded_output_stream(output_stream_
.get());
89 DVLOG(1) << "Writing proto of size " << message
.ByteSize();
90 int tag
= GetMCSProtoTag(message
);
// NOTE(review): WriteRaw(&tag, 1) emits the first byte of the int in memory.
// That is the low-order byte only on little-endian platforms. Copying the tag
// into a one-byte variable first would make this endian-independent.
92 coded_output_stream
.WriteRaw(&tag
, 1);
93 coded_output_stream
.WriteVarint32(message
.ByteSize());
94 message
.SerializeToCodedStream(&coded_output_stream
);
// When Flush() returns net::ERR_IO_PENDING, completion is reported through the
// bound callback. For any other result the if-body must report completion
// itself; Login() posts OnMessageSent() in that case.
97 if (output_stream_
->Flush(
98 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
99 weak_ptr_factory_
.GetWeakPtr())) != net::ERR_IO_PENDING
) {
// Sends the login packet and starts waiting for the server's response.
// The packet layout is:
//   [version byte (kMCSVersion)][login request tag]
//   [varint32 byte size][serialized |login_request|]
// After sending, this starts read_timeout_timer_ (which fires OnTimeout()) and
// waits for the server's version/tag/size prefix (MCS_VERSION_TAG_AND_SIZE).
// Precondition (DCHECKed): the output stream is EMPTY.
104 void ConnectionHandlerImpl::Login(
105 const google::protobuf::MessageLite
& login_request
) {
106 DCHECK_EQ(output_stream_
->GetState(), SocketOutputStream::EMPTY
);
108 const char version_byte
[1] = {kMCSVersion
};
109 const char login_request_tag
[1] = {kLoginRequestTag
};
111 CodedOutputStream
coded_output_stream(output_stream_
.get());
112 coded_output_stream
.WriteRaw(version_byte
, 1);
113 coded_output_stream
.WriteRaw(login_request_tag
, 1);
114 coded_output_stream
.WriteVarint32(login_request
.ByteSize());
115 login_request
.SerializeToCodedStream(&coded_output_stream
);
// If Flush() did not return ERR_IO_PENDING, its callback presumably will not
// run. OnMessageSent() is therefore posted explicitly, so completion is still
// reported asynchronously.
118 if (output_stream_
->Flush(
119 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
120 weak_ptr_factory_
.GetWeakPtr())) != net::ERR_IO_PENDING
) {
121 base::MessageLoop::current()->PostTask(
123 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
124 weak_ptr_factory_
.GetWeakPtr()));
// Bound the time spent waiting for the server's login response.
127 read_timeout_timer_
.Start(FROM_HERE
,
129 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
130 weak_ptr_factory_
.GetWeakPtr()));
131 WaitForData(MCS_VERSION_TAG_AND_SIZE
);
// Completion handler for output-stream flushes.
//   - If the connection has already been closed (no output stream), it does
//     nothing.
//   - If the stream is not EMPTY after the flush, it reports the stream's last
//     error through connection_callback_, or net::ERR_FAILED if none was
//     recorded.
//   - Otherwise it runs write_callback_.
134 void ConnectionHandlerImpl::OnMessageSent() {
135 if (!output_stream_
.get()) {
136 // The connection has already been closed. Just return.
137 DCHECK(!input_stream_
.get());
138 DCHECK(!read_timeout_timer_
.IsRunning());
142 if (output_stream_
->GetState() != SocketOutputStream::EMPTY
) {
143 int last_error
= output_stream_
->last_error();
145 // If the socket stream had an error, plumb it up, else plumb up FAILED.
146 if (last_error
== net::OK
)
147 last_error
= net::ERR_FAILED
;
148 connection_callback_
.Run(last_error
);
152 write_callback_
.Run();
// Starts reading the next incoming message by waiting for its tag and size
// prefix (MCS_TAG_AND_SIZE). The input stream must be EMPTY or READY (DCHECKed).
155 void ConnectionHandlerImpl::GetNextMessage() {
156 DCHECK(SocketInputStream::EMPTY
== input_stream_
->GetState() ||
157 SocketInputStream::READY
== input_stream_
->GetState());
161 WaitForData(MCS_TAG_AND_SIZE
);
// Makes sure enough bytes are buffered in the input stream for |state|, then
// processes them.
//   - min_bytes_needed: the number of bytes that must be available before
//     processing can start.
//   - max_bytes_needed: the cap on the size of a single socket read.
// If a read is needed and completes asynchronously, WaitForData(state) is
// re-run as the Refresh() callback. If a read completes short, WaitForData is
// re-posted. Input-stream errors are reported through connection_callback_,
// using the stream's last error or net::ERR_FAILED if none was recorded.
164 void ConnectionHandlerImpl::WaitForData(ProcessingState state
) {
165 DVLOG(1) << "Waiting for MCS data: state == " << state
;
167 if (!input_stream_
) {
168 // The connection has already been closed. Just return.
169 DCHECK(!output_stream_
.get());
170 DCHECK(!read_timeout_timer_
.IsRunning());
174 if (input_stream_
->GetState() != SocketInputStream::EMPTY
&&
175 input_stream_
->GetState() != SocketInputStream::READY
) {
176 // An error occurred.
// The stream that failed is the input stream, so report its error code.
// The output stream's last error is unrelated to this failure, and reading it
// would mask real read errors as net::ERR_FAILED.
177 int last_error
= input_stream_
->last_error();
179 // If the socket stream had an error, plumb it up, else plumb up FAILED.
180 if (last_error
== net::OK
)
181 last_error
= net::ERR_FAILED
;
182 connection_callback_
.Run(last_error
);
186 // Used to determine whether a Socket::Read is necessary.
187 size_t min_bytes_needed
= 0;
188 // Used to limit the size of the Socket::Read.
189 size_t max_bytes_needed
= 0;
192 case MCS_VERSION_TAG_AND_SIZE
:
193 min_bytes_needed
= kVersionPacketLen
+ kTagPacketLen
+ kSizePacketLenMin
;
194 max_bytes_needed
= kVersionPacketLen
+ kTagPacketLen
+ kSizePacketLenMax
;
196 case MCS_TAG_AND_SIZE
:
197 min_bytes_needed
= kTagPacketLen
+ kSizePacketLenMin
;
198 max_bytes_needed
= kTagPacketLen
+ kSizePacketLenMax
;
201 // If in this state, the minimum size packet length must already have been
202 // insufficient, so set both to the max length.
203 min_bytes_needed
= kSizePacketLenMax
;
204 max_bytes_needed
= kSizePacketLenMax
;
206 case MCS_PROTO_BYTES
:
207 read_timeout_timer_
.Reset();
208 // No variability in the message size, set both to the same.
209 min_bytes_needed
= message_size_
;
210 max_bytes_needed
= message_size_
;
215 DCHECK_GE(max_bytes_needed
, min_bytes_needed
);
217 size_t unread_byte_count
= input_stream_
->UnreadByteCount();
218 if (min_bytes_needed
> unread_byte_count
&&
219 input_stream_
->Refresh(
220 base::Bind(&ConnectionHandlerImpl::WaitForData
,
221 weak_ptr_factory_
.GetWeakPtr(),
223 max_bytes_needed
- unread_byte_count
) == net::ERR_IO_PENDING
) {
227 // Check for refresh errors.
228 if (input_stream_
->GetState() != SocketInputStream::READY
) {
229 // An error occurred.
230 int last_error
= input_stream_
->last_error();
232 // If the socket stream had an error, plumb it up, else plumb up FAILED.
233 if (last_error
== net::OK
)
234 last_error
= net::ERR_FAILED
;
235 connection_callback_
.Run(last_error
);
239 // Check whether read is complete, or needs to be continued (
240 // SocketInputStream::Refresh can finish without reading all the data).
241 if (input_stream_
->UnreadByteCount() < min_bytes_needed
) {
242 DVLOG(1) << "Socket read finished prematurely. Waiting for "
243 << min_bytes_needed
- input_stream_
->UnreadByteCount()
245 base::MessageLoop::current()->PostTask(
247 base::Bind(&ConnectionHandlerImpl::WaitForData
,
248 weak_ptr_factory_
.GetWeakPtr(),
253 // Received enough bytes, process them.
254 DVLOG(1) << "Processing MCS data: state == " << state
;
256 case MCS_VERSION_TAG_AND_SIZE
:
259 case MCS_TAG_AND_SIZE
:
265 case MCS_PROTO_BYTES
:
// Handles the server's version byte by reading it from the input stream.
//   - If the version is older than kMCSVersion, it reports net::ERR_FAILED via
//     connection_callback_. Version 38 is still accepted for now (see TODO).
//   - Otherwise it rebuilds the input buffer and moves on to the LoginResponse
//     message tag.
273 void ConnectionHandlerImpl::OnGotVersion() {
276 CodedInputStream
coded_input_stream(input_stream_
.get());
277 coded_input_stream
.ReadRaw(&version
, 1);
279 // TODO(zea): remove this when the server is ready.
280 if (version
< kMCSVersion
&& version
!= 38) {
281 LOG(ERROR
) << "Invalid GCM version response: " << static_cast<int>(version
);
282 connection_callback_
.Run(net::ERR_FAILED
);
286 input_stream_
->RebuildBuffer();
288 // Process the LoginResponse message tag.
// Reads the one-byte message tag into message_tag_. It then starts
// read_timeout_timer_ (which fires OnTimeout()) unless the timer is already
// running. If the input stream is not READY, it instead logs an error and runs
// read_callback_ with a null message.
292 void ConnectionHandlerImpl::OnGotMessageTag() {
293 if (input_stream_
->GetState() != SocketInputStream::READY
) {
294 LOG(ERROR
) << "Failed to receive protobuf tag.";
295 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
300 CodedInputStream
coded_input_stream(input_stream_
.get());
301 coded_input_stream
.ReadRaw(&message_tag_
, 1);
304 DVLOG(1) << "Received proto of type "
305 << static_cast<unsigned int>(message_tag_
);
// Time out if the rest of this message does not arrive promptly.
307 if (!read_timeout_timer_
.IsRunning()) {
308 read_timeout_timer_
.Start(FROM_HERE
,
310 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
311 weak_ptr_factory_
.GetWeakPtr()));
// Decodes the varint32 message size into message_size_.
//   - If the varint is incomplete, it backs up the byte that was consumed and
//     waits for the full size field (MCS_FULL_SIZE).
//   - If prev_byte_count had already reached kSizePacketLenMax when the varint
//     turned out incomplete, it instead logs an error and runs read_callback_
//     with a null message.
//   - A positive size then waits for the proto bytes (MCS_PROTO_BYTES).
// NOTE(review): suppose SocketInputStream::ByteCount() follows the
// ZeroCopyInputStream convention (total bytes consumed so far). Then
// |prev_byte_count - ByteCount()| is not positive after a read, which
// contradicts DCHECK_EQ(bytes_read, 1). Comparing that count against
// kSizePacketLenMax is also questionable. Confirm ByteCount() semantics.
316 void ConnectionHandlerImpl::OnGotMessageSize() {
317 if (input_stream_
->GetState() != SocketInputStream::READY
) {
318 LOG(ERROR
) << "Failed to receive message size.";
319 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
323 bool need_another_byte
= false;
324 int prev_byte_count
= input_stream_
->ByteCount();
326 CodedInputStream
coded_input_stream(input_stream_
.get());
// A failed ReadVarint32 means the size field has not been fully received yet.
327 if (!coded_input_stream
.ReadVarint32(&message_size_
))
328 need_another_byte
= true;
331 if (need_another_byte
) {
332 DVLOG(1) << "Expecting another message size byte.";
333 if (prev_byte_count
>= kSizePacketLenMax
) {
334 // Already had enough bytes, something else went wrong.
335 LOG(ERROR
) << "Failed to process message size.";
336 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
339 // Back up by the amount read (should always be 1 byte).
340 int bytes_read
= prev_byte_count
- input_stream_
->ByteCount();
341 DCHECK_EQ(bytes_read
, 1);
342 input_stream_
->BackUp(bytes_read
);
343 WaitForData(MCS_FULL_SIZE
);
347 DVLOG(1) << "Proto size: " << message_size_
;
349 if (message_size_
> 0)
350 WaitForData(MCS_PROTO_BYTES
);
// Stops the read timeout, builds the protobuf for message_tag_ and fills it
// from the input stream.
//   - A zero-size message is delivered as the default instance.
//   - A missing protobuf type, a non-READY stream or a parse failure resets the
//     connection via connection_callback_(net::ERR_FAILED).
//   - A LoginResponse marks the handshake complete and reports net::OK through
//     connection_callback_.
//   - Parsed messages are handed to read_callback_.
// Reading of the next message is posted as a GetNextMessage() task.
// NOTE(review): no limit is pushed on the CodedInputStream, so parsing relies
// on the input stream holding exactly message_size_ unread bytes. Confirm that
// WaitForData(MCS_PROTO_BYTES) guarantees this.
355 void ConnectionHandlerImpl::OnGotMessageBytes() {
356 read_timeout_timer_
.Stop();
357 scoped_ptr
<google::protobuf::MessageLite
> protobuf(
358 BuildProtobufFromTag(message_tag_
));
359 // Messages with no content are valid; just use the default protobuf for
361 if (protobuf
.get() && message_size_
== 0) {
362 base::MessageLoop::current()->PostTask(
364 base::Bind(&ConnectionHandlerImpl::GetNextMessage
,
365 weak_ptr_factory_
.GetWeakPtr()));
366 read_callback_
.Run(protobuf
.Pass());
370 if (!protobuf
.get() ||
371 input_stream_
->GetState() != SocketInputStream::READY
) {
372 LOG(ERROR
) << "Failed to extract protobuf bytes of type "
373 << static_cast<unsigned int>(message_tag_
);
374 // Reset the connection.
375 connection_callback_
.Run(net::ERR_FAILED
);
380 CodedInputStream
coded_input_stream(input_stream_
.get());
381 if (!protobuf
->ParsePartialFromCodedStream(&coded_input_stream
)) {
382 LOG(ERROR
) << "Unable to parse GCM message of type "
383 << static_cast<unsigned int>(message_tag_
);
384 // Reset the connection.
385 connection_callback_
.Run(net::ERR_FAILED
);
390 input_stream_
->RebuildBuffer();
391 base::MessageLoop::current()->PostTask(
393 base::Bind(&ConnectionHandlerImpl::GetNextMessage
,
394 weak_ptr_factory_
.GetWeakPtr()));
// The first LoginResponse completes the handshake; a repeated one is logged.
395 if (message_tag_
== kLoginResponseTag
) {
396 if (handshake_complete_
) {
397 LOG(ERROR
) << "Unexpected login response.";
399 handshake_complete_
= true;
400 DVLOG(1) << "GCM Handshake complete.";
401 connection_callback_
.Run(net::OK
);
404 read_callback_
.Run(protobuf
.Pass());
// Read-timeout handler: logs an error and reports net::ERR_TIMED_OUT through
// connection_callback_.
407 void ConnectionHandlerImpl::OnTimeout() {
408 LOG(ERROR
) << "Timed out waiting for GCM Protocol buffer.";
410 connection_callback_
.Run(net::ERR_TIMED_OUT
);
// Tears down the current connection. It does the following:
//   - Stops the read timeout.
//   - Disconnects the socket.
//   - Clears the handshake flag.
//   - Destroys both streams.
//   - Invalidates all weak pointers, so pending callbacks and tasks bound to
//     this object do not run.
413 void ConnectionHandlerImpl::CloseConnection() {
414 DVLOG(1) << "Closing connection.";
415 read_timeout_timer_
.Stop();
417 socket_
->Disconnect();
419 handshake_complete_
= false;
422 input_stream_
.reset();
423 output_stream_
.reset();
424 weak_ptr_factory_
.InvalidateWeakPtrs();