Merge Chromium + Blink git repositories
[chromium-blink-merge.git] / chrome / test / chromedriver / net / websocket.cc
blobd62f7fba9809e5e270836e9608e38a5088b3871d
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/test/chromedriver/net/websocket.h"
7 #include <string.h>
9 #include "base/base64.h"
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/scoped_vector.h"
13 #include "base/rand_util.h"
14 #include "base/sha1.h"
15 #include "base/strings/string_number_conversions.h"
16 #include "base/strings/stringprintf.h"
17 #include "net/base/address_list.h"
18 #include "net/base/io_buffer.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/net_errors.h"
21 #include "net/base/net_util.h"
22 #include "net/base/sys_addrinfo.h"
23 #include "net/http/http_response_headers.h"
24 #include "net/http/http_util.h"
25 #include "net/websockets/websocket_frame.h"
27 #if defined(OS_WIN)
28 #include <Winsock2.h>
29 #endif
namespace {

// Resolves |host| with getaddrinfo() and stores the first IPv4 or IPv6
// address it yields in |address|.
//
// Returns true only if such an address was found. Returns false if the
// lookup itself fails or if no entry could be converted to an IPv4/IPv6
// address; |address| is not modified in that case.
bool ResolveHost(const std::string& host, net::IPAddressNumber* address) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;

  struct addrinfo* result = NULL;
  if (getaddrinfo(host.c_str(), NULL, &hints, &result))
    return false;

  bool found = false;
  for (struct addrinfo* addr = result; addr && !found; addr = addr->ai_next) {
    if (addr->ai_family != AF_INET && addr->ai_family != AF_INET6)
      continue;
    net::IPEndPoint end_point;
    // Skip an entry that cannot be converted instead of giving up: a later
    // entry in the list may still be usable.
    if (!end_point.FromSockAddr(addr->ai_addr, addr->ai_addrlen))
      continue;
    *address = end_point.address();
    found = true;
  }
  freeaddrinfo(result);
  // Previously this returned true even when no usable address was seen,
  // leaving |address| unset for the caller.
  return found;
}

}  // namespace
59 WebSocket::WebSocket(const GURL& url, WebSocketListener* listener)
60 : url_(url),
61 listener_(listener),
62 state_(INITIALIZED),
63 write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)),
64 read_buffer_(new net::IOBufferWithSize(4096)) {}
66 WebSocket::~WebSocket() {
67 CHECK(thread_checker_.CalledOnValidThread());
70 void WebSocket::Connect(const net::CompletionCallback& callback) {
71 CHECK(thread_checker_.CalledOnValidThread());
72 CHECK_EQ(INITIALIZED, state_);
74 net::IPAddressNumber address;
75 if (!net::ParseIPLiteralToNumber(url_.HostNoBrackets(), &address)) {
76 if (!ResolveHost(url_.HostNoBrackets(), &address)) {
77 callback.Run(net::ERR_ADDRESS_UNREACHABLE);
78 return;
81 net::AddressList addresses(
82 net::IPEndPoint(address, static_cast<uint16>(url_.EffectiveIntPort())));
83 net::NetLog::Source source;
84 socket_.reset(new net::TCPClientSocket(addresses, NULL, source));
86 state_ = CONNECTING;
87 connect_callback_ = callback;
88 int code = socket_->Connect(base::Bind(
89 &WebSocket::OnSocketConnect, base::Unretained(this)));
90 if (code != net::ERR_IO_PENDING)
91 OnSocketConnect(code);
94 bool WebSocket::Send(const std::string& message) {
95 CHECK(thread_checker_.CalledOnValidThread());
96 if (state_ != OPEN)
97 return false;
99 net::WebSocketFrameHeader header(net::WebSocketFrameHeader::kOpCodeText);
100 header.final = true;
101 header.masked = true;
102 header.payload_length = message.length();
103 int header_size = net::GetWebSocketFrameHeaderSize(header);
104 net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey();
105 std::string header_str;
106 header_str.resize(header_size);
107 CHECK_EQ(header_size, net::WriteWebSocketFrameHeader(
108 header, &masking_key, &header_str[0], header_str.length()));
110 std::string masked_message = message;
111 net::MaskWebSocketFramePayload(
112 masking_key, 0, &masked_message[0], masked_message.length());
113 Write(header_str + masked_message);
114 return true;
117 void WebSocket::OnSocketConnect(int code) {
118 if (code != net::OK) {
119 Close(code);
120 return;
123 base::Base64Encode(base::RandBytesAsString(16), &sec_key_);
124 std::string handshake = base::StringPrintf(
125 "GET %s HTTP/1.1\r\n"
126 "Host: %s\r\n"
127 "Upgrade: websocket\r\n"
128 "Connection: Upgrade\r\n"
129 "Sec-WebSocket-Key: %s\r\n"
130 "Sec-WebSocket-Version: 13\r\n"
131 "Pragma: no-cache\r\n"
132 "Cache-Control: no-cache\r\n"
133 "\r\n",
134 url_.path().c_str(),
135 url_.host().c_str(),
136 sec_key_.c_str());
137 Write(handshake);
138 Read();
141 void WebSocket::Write(const std::string& data) {
142 pending_write_ += data;
143 if (!write_buffer_->BytesRemaining())
144 ContinueWritingIfNecessary();
147 void WebSocket::OnWrite(int code) {
148 if (!socket_->IsConnected()) {
149 // Supposedly if |StreamSocket| is closed, the error code may be undefined.
150 Close(net::ERR_FAILED);
151 return;
153 if (code < 0) {
154 Close(code);
155 return;
158 write_buffer_->DidConsume(code);
159 ContinueWritingIfNecessary();
162 void WebSocket::ContinueWritingIfNecessary() {
163 if (!write_buffer_->BytesRemaining()) {
164 if (pending_write_.empty())
165 return;
166 write_buffer_ = new net::DrainableIOBuffer(
167 new net::StringIOBuffer(pending_write_),
168 pending_write_.length());
169 pending_write_.clear();
171 int code =
172 socket_->Write(write_buffer_.get(),
173 write_buffer_->BytesRemaining(),
174 base::Bind(&WebSocket::OnWrite, base::Unretained(this)));
175 if (code != net::ERR_IO_PENDING)
176 OnWrite(code);
179 void WebSocket::Read() {
180 int code =
181 socket_->Read(read_buffer_.get(),
182 read_buffer_->size(),
183 base::Bind(&WebSocket::OnRead, base::Unretained(this)));
184 if (code != net::ERR_IO_PENDING)
185 OnRead(code);
188 void WebSocket::OnRead(int code) {
189 if (code <= 0) {
190 Close(code ? code : net::ERR_FAILED);
191 return;
194 if (state_ == CONNECTING)
195 OnReadDuringHandshake(read_buffer_->data(), code);
196 else if (state_ == OPEN)
197 OnReadDuringOpen(read_buffer_->data(), code);
199 if (state_ != CLOSED)
200 Read();
203 void WebSocket::OnReadDuringHandshake(const char* data, int len) {
204 handshake_response_ += std::string(data, len);
205 int headers_end = net::HttpUtil::LocateEndOfHeaders(
206 handshake_response_.data(), handshake_response_.size(), 0);
207 if (headers_end == -1)
208 return;
210 const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
211 std::string websocket_accept;
212 base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey),
213 &websocket_accept);
214 scoped_refptr<net::HttpResponseHeaders> headers(
215 new net::HttpResponseHeaders(
216 net::HttpUtil::AssembleRawHeaders(
217 handshake_response_.data(), headers_end)));
218 if (headers->response_code() != 101 ||
219 !headers->HasHeaderValue("Upgrade", "WebSocket") ||
220 !headers->HasHeaderValue("Connection", "Upgrade") ||
221 !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) {
222 Close(net::ERR_FAILED);
223 return;
225 std::string leftover_message = handshake_response_.substr(headers_end);
226 handshake_response_.clear();
227 sec_key_.clear();
228 state_ = OPEN;
229 InvokeConnectCallback(net::OK);
230 if (!leftover_message.empty())
231 OnReadDuringOpen(leftover_message.c_str(), leftover_message.length());
234 void WebSocket::OnReadDuringOpen(const char* data, int len) {
235 ScopedVector<net::WebSocketFrameChunk> frame_chunks;
236 CHECK(parser_.Decode(data, len, &frame_chunks));
237 for (size_t i = 0; i < frame_chunks.size(); ++i) {
238 scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data;
239 if (buffer.get())
240 next_message_ += std::string(buffer->data(), buffer->size());
241 if (frame_chunks[i]->final_chunk) {
242 listener_->OnMessageReceived(next_message_);
243 next_message_.clear();
248 void WebSocket::InvokeConnectCallback(int code) {
249 net::CompletionCallback temp = connect_callback_;
250 connect_callback_.Reset();
251 CHECK(!temp.is_null());
252 temp.Run(code);
255 void WebSocket::Close(int code) {
256 socket_->Disconnect();
257 if (!connect_callback_.is_null())
258 InvokeConnectCallback(code);
259 if (state_ == OPEN)
260 listener_->OnClose();
262 state_ = CLOSED;