1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/test/chromedriver/net/websocket.h"
9 #include "base/base64.h"
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/scoped_vector.h"
13 #include "base/rand_util.h"
14 #include "base/sha1.h"
15 #include "base/strings/string_number_conversions.h"
16 #include "base/strings/stringprintf.h"
17 #include "net/base/address_list.h"
18 #include "net/base/io_buffer.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/net_errors.h"
21 #include "net/base/net_util.h"
22 #include "net/base/sys_addrinfo.h"
23 #include "net/http/http_response_headers.h"
24 #include "net/http/http_util.h"
25 #include "net/websockets/websocket_frame.h"
33 bool ResolveHost(const std::string
& host
, net::IPAddressNumber
* address
) {
34 struct addrinfo hints
;
35 memset(&hints
, 0, sizeof(hints
));
36 hints
.ai_family
= AF_UNSPEC
;
37 hints
.ai_socktype
= SOCK_STREAM
;
39 struct addrinfo
* result
;
40 if (getaddrinfo(host
.c_str(), NULL
, &hints
, &result
))
43 for (struct addrinfo
* addr
= result
; addr
; addr
= addr
->ai_next
) {
44 if (addr
->ai_family
== AF_INET
|| addr
->ai_family
== AF_INET6
) {
45 net::IPEndPoint end_point
;
46 if (!end_point
.FromSockAddr(addr
->ai_addr
, addr
->ai_addrlen
)) {
50 *address
= end_point
.address();
59 WebSocket::WebSocket(const GURL
& url
, WebSocketListener
* listener
)
63 write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)),
64 read_buffer_(new net::IOBufferWithSize(4096)) {}
66 WebSocket::~WebSocket() {
67 CHECK(thread_checker_
.CalledOnValidThread());
70 void WebSocket::Connect(const net::CompletionCallback
& callback
) {
71 CHECK(thread_checker_
.CalledOnValidThread());
72 CHECK_EQ(INITIALIZED
, state_
);
74 net::IPAddressNumber address
;
75 if (!net::ParseIPLiteralToNumber(url_
.HostNoBrackets(), &address
)) {
76 if (!ResolveHost(url_
.HostNoBrackets(), &address
)) {
77 callback
.Run(net::ERR_ADDRESS_UNREACHABLE
);
81 net::AddressList
addresses(
82 net::IPEndPoint(address
, static_cast<uint16
>(url_
.EffectiveIntPort())));
83 net::NetLog::Source source
;
84 socket_
.reset(new net::TCPClientSocket(addresses
, NULL
, source
));
87 connect_callback_
= callback
;
88 int code
= socket_
->Connect(base::Bind(
89 &WebSocket::OnSocketConnect
, base::Unretained(this)));
90 if (code
!= net::ERR_IO_PENDING
)
91 OnSocketConnect(code
);
94 bool WebSocket::Send(const std::string
& message
) {
95 CHECK(thread_checker_
.CalledOnValidThread());
99 net::WebSocketFrameHeader
header(net::WebSocketFrameHeader::kOpCodeText
);
101 header
.masked
= true;
102 header
.payload_length
= message
.length();
103 int header_size
= net::GetWebSocketFrameHeaderSize(header
);
104 net::WebSocketMaskingKey masking_key
= net::GenerateWebSocketMaskingKey();
105 std::string header_str
;
106 header_str
.resize(header_size
);
107 CHECK_EQ(header_size
, net::WriteWebSocketFrameHeader(
108 header
, &masking_key
, &header_str
[0], header_str
.length()));
110 std::string masked_message
= message
;
111 net::MaskWebSocketFramePayload(
112 masking_key
, 0, &masked_message
[0], masked_message
.length());
113 Write(header_str
+ masked_message
);
117 void WebSocket::OnSocketConnect(int code
) {
118 if (code
!= net::OK
) {
123 base::Base64Encode(base::RandBytesAsString(16), &sec_key_
);
124 std::string handshake
= base::StringPrintf(
125 "GET %s HTTP/1.1\r\n"
127 "Upgrade: websocket\r\n"
128 "Connection: Upgrade\r\n"
129 "Sec-WebSocket-Key: %s\r\n"
130 "Sec-WebSocket-Version: 13\r\n"
131 "Pragma: no-cache\r\n"
132 "Cache-Control: no-cache\r\n"
141 void WebSocket::Write(const std::string
& data
) {
142 pending_write_
+= data
;
143 if (!write_buffer_
->BytesRemaining())
144 ContinueWritingIfNecessary();
147 void WebSocket::OnWrite(int code
) {
148 if (!socket_
->IsConnected()) {
149 // Supposedly if |StreamSocket| is closed, the error code may be undefined.
150 Close(net::ERR_FAILED
);
158 write_buffer_
->DidConsume(code
);
159 ContinueWritingIfNecessary();
162 void WebSocket::ContinueWritingIfNecessary() {
163 if (!write_buffer_
->BytesRemaining()) {
164 if (pending_write_
.empty())
166 write_buffer_
= new net::DrainableIOBuffer(
167 new net::StringIOBuffer(pending_write_
),
168 pending_write_
.length());
169 pending_write_
.clear();
172 socket_
->Write(write_buffer_
.get(),
173 write_buffer_
->BytesRemaining(),
174 base::Bind(&WebSocket::OnWrite
, base::Unretained(this)));
175 if (code
!= net::ERR_IO_PENDING
)
179 void WebSocket::Read() {
181 socket_
->Read(read_buffer_
.get(),
182 read_buffer_
->size(),
183 base::Bind(&WebSocket::OnRead
, base::Unretained(this)));
184 if (code
!= net::ERR_IO_PENDING
)
188 void WebSocket::OnRead(int code
) {
190 Close(code
? code
: net::ERR_FAILED
);
194 if (state_
== CONNECTING
)
195 OnReadDuringHandshake(read_buffer_
->data(), code
);
196 else if (state_
== OPEN
)
197 OnReadDuringOpen(read_buffer_
->data(), code
);
199 if (state_
!= CLOSED
)
203 void WebSocket::OnReadDuringHandshake(const char* data
, int len
) {
204 handshake_response_
+= std::string(data
, len
);
205 int headers_end
= net::HttpUtil::LocateEndOfHeaders(
206 handshake_response_
.data(), handshake_response_
.size(), 0);
207 if (headers_end
== -1)
210 const char kMagicKey
[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
211 std::string websocket_accept
;
212 base::Base64Encode(base::SHA1HashString(sec_key_
+ kMagicKey
),
214 scoped_refptr
<net::HttpResponseHeaders
> headers(
215 new net::HttpResponseHeaders(
216 net::HttpUtil::AssembleRawHeaders(
217 handshake_response_
.data(), headers_end
)));
218 if (headers
->response_code() != 101 ||
219 !headers
->HasHeaderValue("Upgrade", "WebSocket") ||
220 !headers
->HasHeaderValue("Connection", "Upgrade") ||
221 !headers
->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept
)) {
222 Close(net::ERR_FAILED
);
225 std::string leftover_message
= handshake_response_
.substr(headers_end
);
226 handshake_response_
.clear();
229 InvokeConnectCallback(net::OK
);
230 if (!leftover_message
.empty())
231 OnReadDuringOpen(leftover_message
.c_str(), leftover_message
.length());
234 void WebSocket::OnReadDuringOpen(const char* data
, int len
) {
235 ScopedVector
<net::WebSocketFrameChunk
> frame_chunks
;
236 CHECK(parser_
.Decode(data
, len
, &frame_chunks
));
237 for (size_t i
= 0; i
< frame_chunks
.size(); ++i
) {
238 scoped_refptr
<net::IOBufferWithSize
> buffer
= frame_chunks
[i
]->data
;
240 next_message_
+= std::string(buffer
->data(), buffer
->size());
241 if (frame_chunks
[i
]->final_chunk
) {
242 listener_
->OnMessageReceived(next_message_
);
243 next_message_
.clear();
248 void WebSocket::InvokeConnectCallback(int code
) {
249 net::CompletionCallback temp
= connect_callback_
;
250 connect_callback_
.Reset();
251 CHECK(!temp
.is_null());
255 void WebSocket::Close(int code
) {
256 socket_
->Disconnect();
257 if (!connect_callback_
.is_null())
258 InvokeConnectCallback(code
);
260 listener_
->OnClose();