Suppression for crbug/241044.
[chromium-blink-merge.git] / chrome / test / chromedriver / net / websocket.cc
blobb48c6a3f7008e3580047512118cbacf5651449d1
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/test/chromedriver/net/websocket.h"
7 #include "base/base64.h"
8 #include "base/bind.h"
9 #include "base/bind_helpers.h"
10 #include "base/memory/scoped_vector.h"
11 #include "base/rand_util.h"
12 #include "base/sha1.h"
13 #include "base/stringprintf.h"
14 #include "base/strings/string_number_conversions.h"
15 #include "net/base/address_list.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/ip_endpoint.h"
18 #include "net/base/net_errors.h"
19 #include "net/base/net_util.h"
20 #include "net/http/http_response_headers.h"
21 #include "net/http/http_util.h"
22 #include "net/websockets/websocket_frame.h"
24 WebSocket::WebSocket(const GURL& url, WebSocketListener* listener)
25 : url_(url),
26 listener_(listener),
27 state_(INITIALIZED),
28 write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)),
29 read_buffer_(new net::IOBufferWithSize(4096)) {
30 net::IPAddressNumber address;
31 CHECK(net::ParseIPLiteralToNumber(url_.HostNoBrackets(), &address));
32 int port = 80;
33 base::StringToInt(url_.port(), &port);
34 net::AddressList addresses(net::IPEndPoint(address, port));
35 net::NetLog::Source source;
36 socket_.reset(new net::TCPClientSocket(addresses, NULL, source));
39 WebSocket::~WebSocket() {
40 CHECK(thread_checker_.CalledOnValidThread());
43 void WebSocket::Connect(const net::CompletionCallback& callback) {
44 CHECK(thread_checker_.CalledOnValidThread());
45 CHECK_EQ(INITIALIZED, state_);
46 state_ = CONNECTING;
47 connect_callback_ = callback;
48 int code = socket_->Connect(base::Bind(
49 &WebSocket::OnSocketConnect, base::Unretained(this)));
50 if (code != net::ERR_IO_PENDING)
51 OnSocketConnect(code);
54 bool WebSocket::Send(const std::string& message) {
55 CHECK(thread_checker_.CalledOnValidThread());
56 if (state_ != OPEN)
57 return false;
59 net::WebSocketFrameHeader header(net::WebSocketFrameHeader::kOpCodeText);
60 header.final = true;
61 header.masked = true;
62 header.payload_length = message.length();
63 int header_size = net::GetWebSocketFrameHeaderSize(header);
64 net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey();
65 std::string header_str;
66 header_str.resize(header_size);
67 CHECK_EQ(header_size, net::WriteWebSocketFrameHeader(
68 header, &masking_key, &header_str[0], header_str.length()));
70 std::string masked_message = message;
71 net::MaskWebSocketFramePayload(
72 masking_key, 0, &masked_message[0], masked_message.length());
73 Write(header_str + masked_message);
74 return true;
77 void WebSocket::OnSocketConnect(int code) {
78 if (code != net::OK) {
79 Close(code);
80 return;
83 CHECK(base::Base64Encode(base::RandBytesAsString(16), &sec_key_));
84 std::string handshake = base::StringPrintf(
85 "GET %s HTTP/1.1\r\n"
86 "Host: %s\r\n"
87 "Upgrade: websocket\r\n"
88 "Connection: Upgrade\r\n"
89 "Sec-WebSocket-Key: %s\r\n"
90 "Sec-WebSocket-Version: 13\r\n"
91 "Pragma: no-cache\r\n"
92 "Cache-Control: no-cache\r\n"
93 "\r\n",
94 url_.path().c_str(),
95 url_.host().c_str(),
96 sec_key_.c_str());
97 Write(handshake);
98 Read();
101 void WebSocket::Write(const std::string& data) {
102 pending_write_ += data;
103 if (!write_buffer_->BytesRemaining())
104 ContinueWritingIfNecessary();
107 void WebSocket::OnWrite(int code) {
108 if (!socket_->IsConnected()) {
109 // Supposedly if |StreamSocket| is closed, the error code may be undefined.
110 Close(net::ERR_FAILED);
111 return;
113 if (code < 0) {
114 Close(code);
115 return;
118 write_buffer_->DidConsume(code);
119 ContinueWritingIfNecessary();
122 void WebSocket::ContinueWritingIfNecessary() {
123 if (!write_buffer_->BytesRemaining()) {
124 if (pending_write_.empty())
125 return;
126 write_buffer_ = new net::DrainableIOBuffer(
127 new net::StringIOBuffer(pending_write_),
128 pending_write_.length());
129 pending_write_.clear();
131 int code = socket_->Write(
132 write_buffer_,
133 write_buffer_->BytesRemaining(),
134 base::Bind(&WebSocket::OnWrite, base::Unretained(this)));
135 if (code != net::ERR_IO_PENDING)
136 OnWrite(code);
139 void WebSocket::Read() {
140 int code = socket_->Read(
141 read_buffer_,
142 read_buffer_->size(),
143 base::Bind(&WebSocket::OnRead, base::Unretained(this)));
144 if (code != net::ERR_IO_PENDING)
145 OnRead(code);
148 void WebSocket::OnRead(int code) {
149 if (code <= 0) {
150 Close(code ? code : net::ERR_FAILED);
151 return;
154 if (state_ == CONNECTING)
155 OnReadDuringHandshake(read_buffer_->data(), code);
156 else if (state_ == OPEN)
157 OnReadDuringOpen(read_buffer_->data(), code);
159 if (state_ != CLOSED)
160 Read();
163 void WebSocket::OnReadDuringHandshake(const char* data, int len) {
164 handshake_response_ += std::string(data, len);
165 int headers_end = net::HttpUtil::LocateEndOfHeaders(
166 handshake_response_.data(), handshake_response_.size(), 0);
167 if (headers_end == -1)
168 return;
170 const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
171 std::string websocket_accept;
172 CHECK(base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey),
173 &websocket_accept));
174 scoped_refptr<net::HttpResponseHeaders> headers(
175 new net::HttpResponseHeaders(
176 net::HttpUtil::AssembleRawHeaders(
177 handshake_response_.data(), headers_end)));
178 if (headers->response_code() != 101 ||
179 !headers->HasHeaderValue("Upgrade", "WebSocket") ||
180 !headers->HasHeaderValue("Connection", "Upgrade") ||
181 !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) {
182 Close(net::ERR_FAILED);
183 return;
185 std::string leftover_message = handshake_response_.substr(headers_end);
186 handshake_response_.clear();
187 sec_key_.clear();
188 state_ = OPEN;
189 InvokeConnectCallback(net::OK);
190 if (!leftover_message.empty())
191 OnReadDuringOpen(leftover_message.c_str(), leftover_message.length());
194 void WebSocket::OnReadDuringOpen(const char* data, int len) {
195 ScopedVector<net::WebSocketFrameChunk> frame_chunks;
196 CHECK(parser_.Decode(data, len, &frame_chunks));
197 for (size_t i = 0; i < frame_chunks.size(); ++i) {
198 scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data;
199 if (buffer)
200 next_message_ += std::string(buffer->data(), buffer->size());
201 if (frame_chunks[i]->final_chunk) {
202 listener_->OnMessageReceived(next_message_);
203 next_message_.clear();
208 void WebSocket::InvokeConnectCallback(int code) {
209 net::CompletionCallback temp = connect_callback_;
210 connect_callback_.Reset();
211 CHECK(!temp.is_null());
212 temp.Run(code);
215 void WebSocket::Close(int code) {
216 socket_->Disconnect();
217 if (!connect_callback_.is_null())
218 InvokeConnectCallback(code);
219 if (state_ == OPEN)
220 listener_->OnClose();
222 state_ = CLOSED;