Re-land: C++ readability review
[chromium-blink-merge.git] / google_apis / gcm / engine / connection_handler_impl.cc
blobccb6362936bdd71ff14b72b1e475d9b3831df9dc
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
15 using namespace google::protobuf::io;
17 namespace gcm {
namespace {

// Sizes, in bytes, of the framing that precedes each MCS protobuf on the wire.

// The protocol version byte, sent ahead of the first message (see Login()) and
// expected ahead of the server's first reply (MCS_VERSION_TAG_AND_SIZE).
const int kVersionPacketLen = 1;

// The tag byte that identifies which protobuf type follows.
const int kTagPacketLen = 1;

// Bounds on the length of the Varint32 size field. A Varint32 may occupy up to
// 5 bytes, since the high bit of every byte only signals whether another byte
// follows. However, payloads are limited to 4KiB by the protocol and the socket
// stream buffer is only 8KiB, so two size bytes (values below 16KiB) always
// suffice. Needing more than that is an error: either the socket stream buffer
// overflows, or the size field is rejected as too long.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// The MCS protocol version spoken by this client.
const int kMCSVersion = 41;

}  // namespace
40 ConnectionHandlerImpl::ConnectionHandlerImpl(
41 base::TimeDelta read_timeout,
42 const ProtoReceivedCallback& read_callback,
43 const ProtoSentCallback& write_callback,
44 const ConnectionChangedCallback& connection_callback)
45 : read_timeout_(read_timeout),
46 socket_(NULL),
47 handshake_complete_(false),
48 message_tag_(0),
49 message_size_(0),
50 read_callback_(read_callback),
51 write_callback_(write_callback),
52 connection_callback_(connection_callback),
53 weak_ptr_factory_(this) {
56 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
59 void ConnectionHandlerImpl::Init(
60 const mcs_proto::LoginRequest& login_request,
61 net::StreamSocket* socket) {
62 DCHECK(!read_callback_.is_null());
63 DCHECK(!write_callback_.is_null());
64 DCHECK(!connection_callback_.is_null());
66 // Invalidate any previously outstanding reads.
67 weak_ptr_factory_.InvalidateWeakPtrs();
69 handshake_complete_ = false;
70 message_tag_ = 0;
71 message_size_ = 0;
72 socket_ = socket;
73 input_stream_.reset(new SocketInputStream(socket_));
74 output_stream_.reset(new SocketOutputStream(socket_));
76 Login(login_request);
79 void ConnectionHandlerImpl::Reset() {
80 CloseConnection();
83 bool ConnectionHandlerImpl::CanSendMessage() const {
84 return handshake_complete_ && output_stream_.get() &&
85 output_stream_->GetState() == SocketOutputStream::EMPTY;
88 void ConnectionHandlerImpl::SendMessage(
89 const google::protobuf::MessageLite& message) {
90 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
91 DCHECK(handshake_complete_);
94 CodedOutputStream coded_output_stream(output_stream_.get());
95 DVLOG(1) << "Writing proto of size " << message.ByteSize();
96 int tag = GetMCSProtoTag(message);
97 DCHECK_NE(tag, -1);
98 coded_output_stream.WriteRaw(&tag, 1);
99 coded_output_stream.WriteVarint32(message.ByteSize());
100 message.SerializeToCodedStream(&coded_output_stream);
103 if (output_stream_->Flush(
104 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
105 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
106 OnMessageSent();
110 void ConnectionHandlerImpl::Login(
111 const google::protobuf::MessageLite& login_request) {
112 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
114 const char version_byte[1] = {kMCSVersion};
115 const char login_request_tag[1] = {kLoginRequestTag};
117 CodedOutputStream coded_output_stream(output_stream_.get());
118 coded_output_stream.WriteRaw(version_byte, 1);
119 coded_output_stream.WriteRaw(login_request_tag, 1);
120 coded_output_stream.WriteVarint32(login_request.ByteSize());
121 login_request.SerializeToCodedStream(&coded_output_stream);
124 if (output_stream_->Flush(
125 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
126 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
127 base::MessageLoop::current()->PostTask(
128 FROM_HERE,
129 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
130 weak_ptr_factory_.GetWeakPtr()));
133 read_timeout_timer_.Start(FROM_HERE,
134 read_timeout_,
135 base::Bind(&ConnectionHandlerImpl::OnTimeout,
136 weak_ptr_factory_.GetWeakPtr()));
137 WaitForData(MCS_VERSION_TAG_AND_SIZE);
140 void ConnectionHandlerImpl::OnMessageSent() {
141 if (!output_stream_.get()) {
142 // The connection has already been closed. Just return.
143 DCHECK(!input_stream_.get());
144 DCHECK(!read_timeout_timer_.IsRunning());
145 return;
148 if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
149 int last_error = output_stream_->last_error();
150 CloseConnection();
151 // If the socket stream had an error, plumb it up, else plumb up FAILED.
152 if (last_error == net::OK)
153 last_error = net::ERR_FAILED;
154 connection_callback_.Run(last_error);
155 return;
158 write_callback_.Run();
161 void ConnectionHandlerImpl::GetNextMessage() {
162 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
163 SocketInputStream::READY == input_stream_->GetState());
164 message_tag_ = 0;
165 message_size_ = 0;
167 WaitForData(MCS_TAG_AND_SIZE);
170 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
171 DVLOG(1) << "Waiting for MCS data: state == " << state;
173 if (!input_stream_) {
174 // The connection has already been closed. Just return.
175 DCHECK(!output_stream_.get());
176 DCHECK(!read_timeout_timer_.IsRunning());
177 return;
180 if (input_stream_->GetState() != SocketInputStream::EMPTY &&
181 input_stream_->GetState() != SocketInputStream::READY) {
182 // An error occurred.
183 int last_error = output_stream_->last_error();
184 CloseConnection();
185 // If the socket stream had an error, plumb it up, else plumb up FAILED.
186 if (last_error == net::OK)
187 last_error = net::ERR_FAILED;
188 connection_callback_.Run(last_error);
189 return;
192 // Used to determine whether a Socket::Read is necessary.
193 int min_bytes_needed = 0;
194 // Used to limit the size of the Socket::Read.
195 int max_bytes_needed = 0;
197 switch(state) {
198 case MCS_VERSION_TAG_AND_SIZE:
199 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
200 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
201 break;
202 case MCS_TAG_AND_SIZE:
203 min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
204 max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
205 break;
206 case MCS_FULL_SIZE:
207 // If in this state, the minimum size packet length must already have been
208 // insufficient, so set both to the max length.
209 min_bytes_needed = kSizePacketLenMax;
210 max_bytes_needed = kSizePacketLenMax;
211 break;
212 case MCS_PROTO_BYTES:
213 read_timeout_timer_.Reset();
214 // No variability in the message size, set both to the same.
215 min_bytes_needed = message_size_;
216 max_bytes_needed = message_size_;
217 break;
218 default:
219 NOTREACHED();
221 DCHECK_GE(max_bytes_needed, min_bytes_needed);
223 int unread_byte_count = input_stream_->UnreadByteCount();
224 if (min_bytes_needed > unread_byte_count &&
225 input_stream_->Refresh(
226 base::Bind(&ConnectionHandlerImpl::WaitForData,
227 weak_ptr_factory_.GetWeakPtr(),
228 state),
229 max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
230 return;
233 // Check for refresh errors.
234 if (input_stream_->GetState() != SocketInputStream::READY) {
235 // An error occurred.
236 int last_error = input_stream_->last_error();
237 CloseConnection();
238 // If the socket stream had an error, plumb it up, else plumb up FAILED.
239 if (last_error == net::OK)
240 last_error = net::ERR_FAILED;
241 connection_callback_.Run(last_error);
242 return;
245 // Check whether read is complete, or needs to be continued (
246 // SocketInputStream::Refresh can finish without reading all the data).
247 if (input_stream_->UnreadByteCount() < min_bytes_needed) {
248 DVLOG(1) << "Socket read finished prematurely. Waiting for "
249 << min_bytes_needed - input_stream_->UnreadByteCount()
250 << " more bytes.";
251 base::MessageLoop::current()->PostTask(
252 FROM_HERE,
253 base::Bind(&ConnectionHandlerImpl::WaitForData,
254 weak_ptr_factory_.GetWeakPtr(),
255 MCS_PROTO_BYTES));
256 return;
259 // Received enough bytes, process them.
260 DVLOG(1) << "Processing MCS data: state == " << state;
261 switch(state) {
262 case MCS_VERSION_TAG_AND_SIZE:
263 OnGotVersion();
264 break;
265 case MCS_TAG_AND_SIZE:
266 OnGotMessageTag();
267 break;
268 case MCS_FULL_SIZE:
269 OnGotMessageSize();
270 break;
271 case MCS_PROTO_BYTES:
272 OnGotMessageBytes();
273 break;
274 default:
275 NOTREACHED();
279 void ConnectionHandlerImpl::OnGotVersion() {
280 uint8 version = 0;
282 CodedInputStream coded_input_stream(input_stream_.get());
283 coded_input_stream.ReadRaw(&version, 1);
285 // TODO(zea): remove this when the server is ready.
286 if (version < kMCSVersion && version != 38) {
287 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
288 connection_callback_.Run(net::ERR_FAILED);
289 return;
292 input_stream_->RebuildBuffer();
294 // Process the LoginResponse message tag.
295 OnGotMessageTag();
298 void ConnectionHandlerImpl::OnGotMessageTag() {
299 if (input_stream_->GetState() != SocketInputStream::READY) {
300 LOG(ERROR) << "Failed to receive protobuf tag.";
301 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
302 return;
306 CodedInputStream coded_input_stream(input_stream_.get());
307 coded_input_stream.ReadRaw(&message_tag_, 1);
310 DVLOG(1) << "Received proto of type "
311 << static_cast<unsigned int>(message_tag_);
313 if (!read_timeout_timer_.IsRunning()) {
314 read_timeout_timer_.Start(FROM_HERE,
315 read_timeout_,
316 base::Bind(&ConnectionHandlerImpl::OnTimeout,
317 weak_ptr_factory_.GetWeakPtr()));
319 OnGotMessageSize();
322 void ConnectionHandlerImpl::OnGotMessageSize() {
323 if (input_stream_->GetState() != SocketInputStream::READY) {
324 LOG(ERROR) << "Failed to receive message size.";
325 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
326 return;
329 bool need_another_byte = false;
330 int prev_byte_count = input_stream_->UnreadByteCount();
332 CodedInputStream coded_input_stream(input_stream_.get());
333 if (!coded_input_stream.ReadVarint32(&message_size_))
334 need_another_byte = true;
337 if (need_another_byte) {
338 DVLOG(1) << "Expecting another message size byte.";
339 if (prev_byte_count >= kSizePacketLenMax) {
340 // Already had enough bytes, something else went wrong.
341 LOG(ERROR) << "Failed to process message size, too many bytes needed.";
342 connection_callback_.Run(net::ERR_FILE_TOO_BIG);
343 return;
345 // Back up by the amount read (should always be 1 byte).
346 int bytes_read = prev_byte_count - input_stream_->UnreadByteCount();
347 DCHECK_EQ(bytes_read, 1);
348 input_stream_->BackUp(bytes_read);
349 WaitForData(MCS_FULL_SIZE);
350 return;
353 DVLOG(1) << "Proto size: " << message_size_;
355 if (message_size_ > 0)
356 WaitForData(MCS_PROTO_BYTES);
357 else
358 OnGotMessageBytes();
361 void ConnectionHandlerImpl::OnGotMessageBytes() {
362 read_timeout_timer_.Stop();
363 scoped_ptr<google::protobuf::MessageLite> protobuf(
364 BuildProtobufFromTag(message_tag_));
365 // Messages with no content are valid; just use the default protobuf for
366 // that tag.
367 if (protobuf.get() && message_size_ == 0) {
368 base::MessageLoop::current()->PostTask(
369 FROM_HERE,
370 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
371 weak_ptr_factory_.GetWeakPtr()));
372 read_callback_.Run(protobuf.Pass());
373 return;
376 if (input_stream_->GetState() != SocketInputStream::READY) {
377 LOG(ERROR) << "Failed to extract protobuf bytes of type "
378 << static_cast<unsigned int>(message_tag_);
379 // Reset the connection.
380 connection_callback_.Run(net::ERR_FAILED);
381 return;
384 if (!protobuf.get()) {
385 LOG(ERROR) << "Received message of invalid type "
386 << static_cast<unsigned int>(message_tag_);
387 connection_callback_.Run(net::ERR_INVALID_ARGUMENT);
388 return;
392 CodedInputStream coded_input_stream(input_stream_.get());
393 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
394 LOG(ERROR) << "Unable to parse GCM message of type "
395 << static_cast<unsigned int>(message_tag_);
396 // Reset the connection.
397 connection_callback_.Run(net::ERR_FAILED);
398 return;
402 input_stream_->RebuildBuffer();
403 base::MessageLoop::current()->PostTask(
404 FROM_HERE,
405 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
406 weak_ptr_factory_.GetWeakPtr()));
407 if (message_tag_ == kLoginResponseTag) {
408 if (handshake_complete_) {
409 LOG(ERROR) << "Unexpected login response.";
410 } else {
411 handshake_complete_ = true;
412 DVLOG(1) << "GCM Handshake complete.";
413 connection_callback_.Run(net::OK);
416 read_callback_.Run(protobuf.Pass());
419 void ConnectionHandlerImpl::OnTimeout() {
420 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
421 CloseConnection();
422 connection_callback_.Run(net::ERR_TIMED_OUT);
425 void ConnectionHandlerImpl::CloseConnection() {
426 DVLOG(1) << "Closing connection.";
427 read_timeout_timer_.Stop();
428 if (socket_)
429 socket_->Disconnect();
430 socket_ = NULL;
431 handshake_complete_ = false;
432 message_tag_ = 0;
433 message_size_ = 0;
434 input_stream_.reset();
435 output_stream_.reset();
436 weak_ptr_factory_.InvalidateWeakPtrs();
439 } // namespace gcm