Roll src/third_party/WebKit 3aea697:d9c6159 (svn 201973:201974)
[chromium-blink-merge.git] / google_apis / gcm / engine / connection_handler_impl.cc
blob4d3fc82745c7c9d7dc5f9b326af0973b706bbde6
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
7 #include "base/location.h"
8 #include "base/thread_task_runner_handle.h"
9 #include "google/protobuf/io/coded_stream.h"
10 #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
11 #include "google_apis/gcm/base/mcs_util.h"
12 #include "google_apis/gcm/base/socket_stream.h"
13 #include "google_apis/gcm/protocol/mcs.pb.h"
14 #include "net/base/net_errors.h"
15 #include "net/socket/stream_socket.h"
17 using namespace google::protobuf::io;
19 namespace gcm {
namespace {

// Size, in bytes, of the MCS version packet.
const int kVersionPacketLen = 1;
// Size, in bytes, of a message tag packet.
const int kTagPacketLen = 1;
// Bounds on the size, in bytes, of a length packet. The length is encoded as a
// Varint32, which occupies between 1 and 5 bytes (the most significant bit of
// each byte signals whether another byte follows).
// Although the protocol currently only allows 4KiB payloads, and the socket
// stream buffer is only 8KiB, some applications may send larger messages.
// Such payloads are accumulated in a temporary in-memory buffer instead of
// being parsed in place from the socket stream buffer.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 5;

// Normal upper bound (4KiB) on a data packet. Data packets of this size or
// larger are accumulated in the temporary in-memory buffer.
const int kDefaultDataPacketLimit = 1024 * 4;

// The MCS protocol version spoken by this client.
const int kMCSVersion = 41;

}  // namespace
45 ConnectionHandlerImpl::ConnectionHandlerImpl(
46 base::TimeDelta read_timeout,
47 const ProtoReceivedCallback& read_callback,
48 const ProtoSentCallback& write_callback,
49 const ConnectionChangedCallback& connection_callback)
50 : read_timeout_(read_timeout),
51 socket_(NULL),
52 handshake_complete_(false),
53 message_tag_(0),
54 message_size_(0),
55 read_callback_(read_callback),
56 write_callback_(write_callback),
57 connection_callback_(connection_callback),
58 size_packet_so_far_(0),
59 weak_ptr_factory_(this) {
62 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
65 void ConnectionHandlerImpl::Init(
66 const mcs_proto::LoginRequest& login_request,
67 net::StreamSocket* socket) {
68 DCHECK(!read_callback_.is_null());
69 DCHECK(!write_callback_.is_null());
70 DCHECK(!connection_callback_.is_null());
72 // Invalidate any previously outstanding reads.
73 weak_ptr_factory_.InvalidateWeakPtrs();
75 handshake_complete_ = false;
76 message_tag_ = 0;
77 message_size_ = 0;
78 socket_ = socket;
79 input_stream_.reset(new SocketInputStream(socket_));
80 output_stream_.reset(new SocketOutputStream(socket_));
82 Login(login_request);
85 void ConnectionHandlerImpl::Reset() {
86 CloseConnection();
89 bool ConnectionHandlerImpl::CanSendMessage() const {
90 return handshake_complete_ && output_stream_.get() &&
91 output_stream_->GetState() == SocketOutputStream::EMPTY;
94 void ConnectionHandlerImpl::SendMessage(
95 const google::protobuf::MessageLite& message) {
96 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
97 DCHECK(handshake_complete_);
100 CodedOutputStream coded_output_stream(output_stream_.get());
101 DVLOG(1) << "Writing proto of size " << message.ByteSize();
102 int tag = GetMCSProtoTag(message);
103 DCHECK_NE(tag, -1);
104 coded_output_stream.WriteRaw(&tag, 1);
105 coded_output_stream.WriteVarint32(message.ByteSize());
106 message.SerializeToCodedStream(&coded_output_stream);
109 if (output_stream_->Flush(
110 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
111 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
112 OnMessageSent();
116 void ConnectionHandlerImpl::Login(
117 const google::protobuf::MessageLite& login_request) {
118 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
120 const char version_byte[1] = {kMCSVersion};
121 const char login_request_tag[1] = {kLoginRequestTag};
123 CodedOutputStream coded_output_stream(output_stream_.get());
124 coded_output_stream.WriteRaw(version_byte, 1);
125 coded_output_stream.WriteRaw(login_request_tag, 1);
126 coded_output_stream.WriteVarint32(login_request.ByteSize());
127 login_request.SerializeToCodedStream(&coded_output_stream);
130 if (output_stream_->Flush(
131 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
132 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
133 base::ThreadTaskRunnerHandle::Get()->PostTask(
134 FROM_HERE,
135 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
136 weak_ptr_factory_.GetWeakPtr()));
139 read_timeout_timer_.Start(FROM_HERE,
140 read_timeout_,
141 base::Bind(&ConnectionHandlerImpl::OnTimeout,
142 weak_ptr_factory_.GetWeakPtr()));
143 WaitForData(MCS_VERSION_TAG_AND_SIZE);
146 void ConnectionHandlerImpl::OnMessageSent() {
147 if (!output_stream_.get()) {
148 // The connection has already been closed. Just return.
149 DCHECK(!input_stream_.get());
150 DCHECK(!read_timeout_timer_.IsRunning());
151 return;
154 if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
155 int last_error = output_stream_->last_error();
156 CloseConnection();
157 // If the socket stream had an error, plumb it up, else plumb up FAILED.
158 if (last_error == net::OK)
159 last_error = net::ERR_FAILED;
160 connection_callback_.Run(last_error);
161 return;
164 write_callback_.Run();
167 void ConnectionHandlerImpl::GetNextMessage() {
168 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
169 SocketInputStream::READY == input_stream_->GetState());
170 message_tag_ = 0;
171 message_size_ = 0;
173 WaitForData(MCS_TAG_AND_SIZE);
176 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
177 DVLOG(1) << "Waiting for MCS data: state == " << state;
179 if (!input_stream_) {
180 // The connection has already been closed. Just return.
181 DCHECK(!output_stream_.get());
182 DCHECK(!read_timeout_timer_.IsRunning());
183 return;
186 if (input_stream_->GetState() != SocketInputStream::EMPTY &&
187 input_stream_->GetState() != SocketInputStream::READY) {
188 // An error occurred.
189 int last_error = output_stream_->last_error();
190 CloseConnection();
191 // If the socket stream had an error, plumb it up, else plumb up FAILED.
192 if (last_error == net::OK)
193 last_error = net::ERR_FAILED;
194 connection_callback_.Run(last_error);
195 return;
198 // Used to determine whether a Socket::Read is necessary.
199 int min_bytes_needed = 0;
200 // Used to limit the size of the Socket::Read.
201 int max_bytes_needed = 0;
203 switch(state) {
204 case MCS_VERSION_TAG_AND_SIZE:
205 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
206 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
207 break;
208 case MCS_TAG_AND_SIZE:
209 min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
210 max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
211 break;
212 case MCS_SIZE:
213 min_bytes_needed = size_packet_so_far_ + 1;
214 max_bytes_needed = kSizePacketLenMax;
215 break;
216 case MCS_PROTO_BYTES:
217 read_timeout_timer_.Reset();
218 if (message_size_ < kDefaultDataPacketLimit) {
219 // No variability in the message size, set both to the same.
220 min_bytes_needed = message_size_;
221 max_bytes_needed = message_size_;
222 } else {
223 int bytes_left = message_size_ - payload_input_buffer_.size();
224 if (bytes_left > kDefaultDataPacketLimit)
225 bytes_left = kDefaultDataPacketLimit;
226 min_bytes_needed = bytes_left;
227 max_bytes_needed = bytes_left;
229 break;
230 default:
231 NOTREACHED();
233 DCHECK_GE(max_bytes_needed, min_bytes_needed);
235 int unread_byte_count = input_stream_->UnreadByteCount();
236 if (min_bytes_needed > unread_byte_count &&
237 input_stream_->Refresh(
238 base::Bind(&ConnectionHandlerImpl::WaitForData,
239 weak_ptr_factory_.GetWeakPtr(),
240 state),
241 max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
242 return;
245 // Check for refresh errors.
246 if (input_stream_->GetState() != SocketInputStream::READY) {
247 // An error occurred.
248 int last_error = input_stream_->last_error();
249 CloseConnection();
250 // If the socket stream had an error, plumb it up, else plumb up FAILED.
251 if (last_error == net::OK)
252 last_error = net::ERR_FAILED;
253 connection_callback_.Run(last_error);
254 return;
257 // Check whether read is complete, or needs to be continued (
258 // SocketInputStream::Refresh can finish without reading all the data).
259 if (input_stream_->UnreadByteCount() < min_bytes_needed) {
260 DVLOG(1) << "Socket read finished prematurely. Waiting for "
261 << min_bytes_needed - input_stream_->UnreadByteCount()
262 << " more bytes.";
263 base::ThreadTaskRunnerHandle::Get()->PostTask(
264 FROM_HERE,
265 base::Bind(&ConnectionHandlerImpl::WaitForData,
266 weak_ptr_factory_.GetWeakPtr(),
267 MCS_PROTO_BYTES));
268 return;
271 // Received enough bytes, process them.
272 DVLOG(1) << "Processing MCS data: state == " << state;
273 switch(state) {
274 case MCS_VERSION_TAG_AND_SIZE:
275 OnGotVersion();
276 break;
277 case MCS_TAG_AND_SIZE:
278 OnGotMessageTag();
279 break;
280 case MCS_SIZE:
281 OnGotMessageSize();
282 break;
283 case MCS_PROTO_BYTES:
284 OnGotMessageBytes();
285 break;
286 default:
287 NOTREACHED();
291 void ConnectionHandlerImpl::OnGotVersion() {
292 uint8 version = 0;
294 CodedInputStream coded_input_stream(input_stream_.get());
295 coded_input_stream.ReadRaw(&version, 1);
297 // TODO(zea): remove this when the server is ready.
298 if (version < kMCSVersion && version != 38) {
299 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
300 connection_callback_.Run(net::ERR_FAILED);
301 return;
304 input_stream_->RebuildBuffer();
306 // Process the LoginResponse message tag.
307 OnGotMessageTag();
310 void ConnectionHandlerImpl::OnGotMessageTag() {
311 if (input_stream_->GetState() != SocketInputStream::READY) {
312 LOG(ERROR) << "Failed to receive protobuf tag.";
313 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
314 return;
318 CodedInputStream coded_input_stream(input_stream_.get());
319 coded_input_stream.ReadRaw(&message_tag_, 1);
322 DVLOG(1) << "Received proto of type "
323 << static_cast<unsigned int>(message_tag_);
325 if (!read_timeout_timer_.IsRunning()) {
326 read_timeout_timer_.Start(FROM_HERE,
327 read_timeout_,
328 base::Bind(&ConnectionHandlerImpl::OnTimeout,
329 weak_ptr_factory_.GetWeakPtr()));
331 OnGotMessageSize();
334 void ConnectionHandlerImpl::OnGotMessageSize() {
335 if (input_stream_->GetState() != SocketInputStream::READY) {
336 LOG(ERROR) << "Failed to receive message size.";
337 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
338 return;
341 int prev_byte_count = input_stream_->UnreadByteCount();
343 CodedInputStream coded_input_stream(input_stream_.get());
344 if (!coded_input_stream.ReadVarint32(&message_size_)) {
345 DVLOG(1) << "Expecting another message size byte.";
346 if (prev_byte_count >= kSizePacketLenMax) {
347 // Already had enough bytes, something else went wrong.
348 LOG(ERROR) << "Failed to process message size";
349 connection_callback_.Run(net::ERR_FILE_TOO_BIG);
350 return;
352 // Back up by the amount read.
353 int bytes_read = prev_byte_count - input_stream_->UnreadByteCount();
354 input_stream_->BackUp(bytes_read);
355 size_packet_so_far_ = bytes_read;
356 WaitForData(MCS_SIZE);
357 return;
361 DVLOG(1) << "Proto size: " << message_size_;
362 size_packet_so_far_ = 0;
363 payload_input_buffer_.clear();
365 if (message_size_ > 0)
366 WaitForData(MCS_PROTO_BYTES);
367 else
368 OnGotMessageBytes();
371 void ConnectionHandlerImpl::OnGotMessageBytes() {
372 read_timeout_timer_.Stop();
373 scoped_ptr<google::protobuf::MessageLite> protobuf(
374 BuildProtobufFromTag(message_tag_));
375 // Messages with no content are valid; just use the default protobuf for
376 // that tag.
377 if (protobuf.get() && message_size_ == 0) {
378 base::ThreadTaskRunnerHandle::Get()->PostTask(
379 FROM_HERE,
380 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
381 weak_ptr_factory_.GetWeakPtr()));
382 read_callback_.Run(protobuf.Pass());
383 return;
386 if (input_stream_->GetState() != SocketInputStream::READY) {
387 LOG(ERROR) << "Failed to extract protobuf bytes of type "
388 << static_cast<unsigned int>(message_tag_);
389 // Reset the connection.
390 connection_callback_.Run(net::ERR_FAILED);
391 return;
394 if (!protobuf.get()) {
395 LOG(ERROR) << "Received message of invalid type "
396 << static_cast<unsigned int>(message_tag_);
397 connection_callback_.Run(net::ERR_INVALID_ARGUMENT);
398 return;
401 if (message_size_ < kDefaultDataPacketLimit) {
402 CodedInputStream coded_input_stream(input_stream_.get());
403 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
404 LOG(ERROR) << "Unable to parse GCM message of type "
405 << static_cast<unsigned int>(message_tag_);
406 // Reset the connection.
407 connection_callback_.Run(net::ERR_FAILED);
408 return;
410 } else {
411 // Copy any data in the input stream onto the end of the buffer.
412 const void* data_ptr = NULL;
413 int size = 0;
414 input_stream_->Next(&data_ptr, &size);
415 payload_input_buffer_.insert(payload_input_buffer_.end(),
416 static_cast<const uint8*>(data_ptr),
417 static_cast<const uint8*>(data_ptr) + size);
418 DCHECK_LE(payload_input_buffer_.size(), message_size_);
420 if (payload_input_buffer_.size() == message_size_) {
421 ArrayInputStream buffer_input_stream(payload_input_buffer_.data(),
422 payload_input_buffer_.size());
423 CodedInputStream coded_input_stream(&buffer_input_stream);
424 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
425 LOG(ERROR) << "Unable to parse GCM message of type "
426 << static_cast<unsigned int>(message_tag_);
427 // Reset the connection.
428 connection_callback_.Run(net::ERR_FAILED);
429 return;
431 } else {
432 // Continue reading data.
433 DVLOG(1) << "Continuing data read. Buffer size is "
434 << payload_input_buffer_.size()
435 << ", expecting " << message_size_;
436 input_stream_->RebuildBuffer();
438 read_timeout_timer_.Start(FROM_HERE,
439 read_timeout_,
440 base::Bind(&ConnectionHandlerImpl::OnTimeout,
441 weak_ptr_factory_.GetWeakPtr()));
442 WaitForData(MCS_PROTO_BYTES);
443 return;
447 input_stream_->RebuildBuffer();
448 base::ThreadTaskRunnerHandle::Get()->PostTask(
449 FROM_HERE,
450 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
451 weak_ptr_factory_.GetWeakPtr()));
452 if (message_tag_ == kLoginResponseTag) {
453 if (handshake_complete_) {
454 LOG(ERROR) << "Unexpected login response.";
455 } else {
456 handshake_complete_ = true;
457 DVLOG(1) << "GCM Handshake complete.";
458 connection_callback_.Run(net::OK);
461 read_callback_.Run(protobuf.Pass());
464 void ConnectionHandlerImpl::OnTimeout() {
465 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
466 CloseConnection();
467 connection_callback_.Run(net::ERR_TIMED_OUT);
470 void ConnectionHandlerImpl::CloseConnection() {
471 DVLOG(1) << "Closing connection.";
472 read_timeout_timer_.Stop();
473 if (socket_)
474 socket_->Disconnect();
475 socket_ = NULL;
476 handshake_complete_ = false;
477 message_tag_ = 0;
478 message_size_ = 0;
479 size_packet_so_far_ = 0;
480 payload_input_buffer_.clear();
481 input_stream_.reset();
482 output_stream_.reset();
483 weak_ptr_factory_.InvalidateWeakPtrs();
486 } // namespace gcm