1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/test/chromedriver/net/websocket.h"
9 #include "base/base64.h"
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/scoped_vector.h"
13 #include "base/rand_util.h"
14 #include "base/sha1.h"
15 #include "base/strings/string_number_conversions.h"
16 #include "base/strings/stringprintf.h"
17 #include "net/base/address_list.h"
18 #include "net/base/io_buffer.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/net_errors.h"
21 #include "net/base/net_util.h"
22 #include "net/base/sys_addrinfo.h"
23 #include "net/http/http_response_headers.h"
24 #include "net/http/http_util.h"
25 #include "net/websockets/websocket_frame.h"
// Looks up |host| with getaddrinfo() and writes the address of an IPv4/IPv6
// result into |address|.
// NOTE(review): this excerpt omits several original lines; the return
// statements and any freeaddrinfo() call are not visible here. Confirm the
// failure paths and that |result| is released before relying on this summary.
33 bool ResolveHost(const std::string
& host
, net::IPAddressNumber
* address
) {
34 struct addrinfo hints
;
35 memset(&hints
, 0, sizeof(hints
));
// Any address family is acceptable; only stream-socket results are requested.
36 hints
.ai_family
= AF_UNSPEC
;
37 hints
.ai_socktype
= SOCK_STREAM
;
39 struct addrinfo
* result
;
// getaddrinfo() reports failure with a non-zero return value.
40 if (getaddrinfo(host
.c_str(), NULL
, &hints
, &result
))
// Walk the returned list, considering only IPv4 and IPv6 entries.
43 for (struct addrinfo
* addr
= result
; addr
; addr
= addr
->ai_next
) {
44 if (addr
->ai_family
== AF_INET
|| addr
->ai_family
== AF_INET6
) {
45 net::IPEndPoint end_point
;
// A sockaddr that cannot be converted is handled in lines not shown here.
46 if (!end_point
.FromSockAddr(addr
->ai_addr
, addr
->ai_addrlen
)) {
50 *address
= end_point
.address();
// Constructs a WebSocket from |url| and |listener|.
// Of the member initializers, only two are visible in this excerpt:
// |write_buffer_| starts as a DrainableIOBuffer over a zero-length IOBuffer
// (so it has no bytes remaining), and |read_buffer_| is a 4096-byte buffer.
// NOTE(review): the initializers on the omitted lines (presumably |url_|,
// |listener_| and |state_|) are not shown -- confirm against the full file.
59 WebSocket::WebSocket(const GURL
& url
, WebSocketListener
* listener
)
63 write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)),
64 read_buffer_(new net::IOBufferWithSize(4096)) {}
// Destructor. CHECKs that it runs on the thread |thread_checker_| considers
// valid. Any further teardown is not visible in this excerpt.
66 WebSocket::~WebSocket() {
67 CHECK(thread_checker_
.CalledOnValidThread());
// Starts connecting a TCP socket to the host and port of |url_|.
// |callback| is stored in |connect_callback_|; if the host can be neither
// parsed as an IP literal nor resolved, |callback| is run directly with
// net::ERR_ADDRESS_UNREACHABLE.
// Must be called on the valid thread and while |state_| is INITIALIZED (both
// enforced with CHECKs).
// NOTE(review): the declaration of |port| and the statement following the
// failure callback are on lines omitted from this excerpt.
70 void WebSocket::Connect(const net::CompletionCallback
& callback
) {
71 CHECK(thread_checker_
.CalledOnValidThread());
72 CHECK_EQ(INITIALIZED
, state_
);
74 net::IPAddressNumber address
;
// Try the host as an IP literal first; fall back to ResolveHost() otherwise.
75 if (!net::ParseIPLiteralToNumber(url_
.HostNoBrackets(), &address
)) {
76 if (!ResolveHost(url_
.HostNoBrackets(), &address
)) {
77 callback
.Run(net::ERR_ADDRESS_UNREACHABLE
);
// NOTE(review): the result of StringToInt() does not appear to be checked
// here; confirm what value |port| has when the URL carries no valid port.
82 base::StringToInt(url_
.port(), &port
);
83 net::AddressList
addresses(net::IPEndPoint(address
, port
));
84 net::NetLog::Source source
;
85 socket_
.reset(new net::TCPClientSocket(addresses
, NULL
, source
));
88 connect_callback_
= callback
;
// The completion callback is bound to a raw |this| (base::Unretained).
89 int code
= socket_
->Connect(base::Bind(
90 &WebSocket::OnSocketConnect
, base::Unretained(this)));
// Any result other than ERR_IO_PENDING is forwarded to OnSocketConnect()
// synchronously.
91 if (code
!= net::ERR_IO_PENDING
)
92 OnSocketConnect(code
);
// Frames |message| as a single masked WebSocket text frame and hands the
// bytes to Write().
// Must be called on the valid thread (enforced with a CHECK).
// NOTE(review): lines between the thread check and the header construction,
// and the return statements, are omitted from this excerpt -- confirm which
// states reject the send and what is returned.
95 bool WebSocket::Send(const std::string
& message
) {
96 CHECK(thread_checker_
.CalledOnValidThread());
100 net::WebSocketFrameHeader
header(net::WebSocketFrameHeader::kOpCodeText
);
102 header
.masked
= true;
103 header
.payload_length
= message
.length();
// Serialize the header into a string sized to exactly fit it.
104 int header_size
= net::GetWebSocketFrameHeaderSize(header
);
105 net::WebSocketMaskingKey masking_key
= net::GenerateWebSocketMaskingKey();
106 std::string header_str
;
107 header_str
.resize(header_size
);
// The number of bytes written must equal the precomputed header size.
108 CHECK_EQ(header_size
, net::WriteWebSocketFrameHeader(
109 header
, &masking_key
, &header_str
[0], header_str
.length()));
// Mask a copy of the payload in place so |message| itself is left untouched.
111 std::string masked_message
= message
;
112 net::MaskWebSocketFramePayload(
113 masking_key
, 0, &masked_message
[0], masked_message
.length());
114 Write(header_str
+ masked_message
);
// Completion handler for the TCP connect started in Connect(); |code| is the
// connect result.
// On the path shown, generates a fresh |sec_key_| (base64 of 16 random bytes)
// and builds the HTTP upgrade request for the WebSocket handshake.
// NOTE(review): the body of the |code| != net::OK branch, the StringPrintf
// arguments and whatever sends |handshake| are on lines omitted from this
// excerpt.
118 void WebSocket::OnSocketConnect(int code
) {
119 if (code
!= net::OK
) {
124 base::Base64Encode(base::RandBytesAsString(16), &sec_key_
);
125 std::string handshake
= base::StringPrintf(
126 "GET %s HTTP/1.1\r\n"
128 "Upgrade: websocket\r\n"
129 "Connection: Upgrade\r\n"
130 "Sec-WebSocket-Key: %s\r\n"
131 "Sec-WebSocket-Version: 13\r\n"
132 "Pragma: no-cache\r\n"
133 "Cache-Control: no-cache\r\n"
// Queues |data| for sending by appending it to |pending_write_|.
// If |write_buffer_| has no bytes remaining, the write loop is started
// immediately via ContinueWritingIfNecessary().
142 void WebSocket::Write(const std::string
& data
) {
143 pending_write_
+= data
;
144 if (!write_buffer_
->BytesRemaining())
145 ContinueWritingIfNecessary();
// Completion handler for a socket write; |code| is the write result.
// If the socket is no longer connected the WebSocket is closed with
// net::ERR_FAILED. On the path shown last, |code| is treated as the number of
// bytes consumed from |write_buffer_| and writing continues.
// NOTE(review): the lines between the Close() call and DidConsume() are
// omitted from this excerpt -- confirm how negative |code| values are handled.
148 void WebSocket::OnWrite(int code
) {
149 if (!socket_
->IsConnected()) {
150 // Supposedly if |StreamSocket| is closed, the error code may be undefined.
151 Close(net::ERR_FAILED
);
159 write_buffer_
->DidConsume(code
);
160 ContinueWritingIfNecessary();
// Drives the write loop. When |write_buffer_| is fully drained, refills it
// from |pending_write_| (and clears |pending_write_|), then issues a socket
// write for the bytes remaining, with OnWrite() as the completion callback.
// NOTE(review): the statement taken when |pending_write_| is empty, the
// declaration of |code| and the statement under the final |if| are on lines
// omitted from this excerpt.
163 void WebSocket::ContinueWritingIfNecessary() {
164 if (!write_buffer_
->BytesRemaining()) {
165 if (pending_write_
.empty())
// Move everything queued so far into a fresh drainable buffer.
167 write_buffer_
= new net::DrainableIOBuffer(
168 new net::StringIOBuffer(pending_write_
),
169 pending_write_
.length());
170 pending_write_
.clear();
// The completion callback is bound to a raw |this| (base::Unretained).
173 socket_
->Write(write_buffer_
.get(),
174 write_buffer_
->BytesRemaining(),
175 base::Bind(&WebSocket::OnWrite
, base::Unretained(this)));
176 if (code
!= net::ERR_IO_PENDING
)
// Issues a socket read into |read_buffer_| (up to its full size) with
// OnRead() as the completion callback, bound to a raw |this|.
// NOTE(review): the declaration of |code| and the statement under the final
// |if| are on lines omitted from this excerpt.
180 void WebSocket::Read() {
182 socket_
->Read(read_buffer_
.get(),
183 read_buffer_
->size(),
184 base::Bind(&WebSocket::OnRead
, base::Unretained(this)));
185 if (code
!= net::ERR_IO_PENDING
)
// Completion handler for a socket read; |code| is the read result.
// Dispatches the bytes in |read_buffer_| by |state_|: CONNECTING goes to
// OnReadDuringHandshake(), OPEN goes to OnReadDuringOpen(), with |code| passed
// as the data length.
// NOTE(review): the condition guarding the Close() call and the statement
// under the final |if| are on lines omitted from this excerpt.
189 void WebSocket::OnRead(int code
) {
// A zero |code| is mapped to net::ERR_FAILED when closing.
191 Close(code
? code
: net::ERR_FAILED
);
195 if (state_
== CONNECTING
)
196 OnReadDuringHandshake(read_buffer_
->data(), code
);
197 else if (state_
== OPEN
)
198 OnReadDuringOpen(read_buffer_
->data(), code
);
200 if (state_
!= CLOSED
)
// Consumes |len| bytes at |data| received while the handshake is in progress.
// Bytes are accumulated in |handshake_response_| until the end of the HTTP
// headers is found. The response is then validated: status 101, an "Upgrade"
// header with value "WebSocket", a "Connection" header with value "Upgrade",
// and a "Sec-WebSocket-Accept" header matching |websocket_accept|. A failed
// validation closes the socket with net::ERR_FAILED. On success the connect
// callback is invoked with net::OK, and any bytes after the headers are
// forwarded to OnReadDuringOpen().
// NOTE(review): several lines are omitted from this excerpt -- the statement
// taken when the headers are incomplete, the second argument of
// Base64Encode(), the statements after Close() and the lines just before
// InvokeConnectCallback() (presumably the state transition).
204 void WebSocket::OnReadDuringHandshake(const char* data
, int len
) {
205 handshake_response_
+= std::string(data
, len
);
206 int headers_end
= net::HttpUtil::LocateEndOfHeaders(
207 handshake_response_
.data(), handshake_response_
.size(), 0);
// -1 means the end of the headers has not been received yet.
208 if (headers_end
== -1)
// The accept value is derived from |sec_key_| concatenated with this fixed
// GUID: SHA-1 hash, then base64.
211 const char kMagicKey
[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
212 std::string websocket_accept
;
213 base::Base64Encode(base::SHA1HashString(sec_key_
+ kMagicKey
),
215 scoped_refptr
<net::HttpResponseHeaders
> headers(
216 new net::HttpResponseHeaders(
217 net::HttpUtil::AssembleRawHeaders(
218 handshake_response_
.data(), headers_end
)));
219 if (headers
->response_code() != 101 ||
220 !headers
->HasHeaderValue("Upgrade", "WebSocket") ||
221 !headers
->HasHeaderValue("Connection", "Upgrade") ||
222 !headers
->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept
)) {
223 Close(net::ERR_FAILED
);
// Bytes after the headers are copied out before the handshake buffer is
// cleared.
226 std::string leftover_message
= handshake_response_
.substr(headers_end
);
227 handshake_response_
.clear();
230 InvokeConnectCallback(net::OK
);
231 if (!leftover_message
.empty())
232 OnReadDuringOpen(leftover_message
.c_str(), leftover_message
.length());
// Consumes |len| bytes at |data| received on an open connection.
// The bytes are decoded into frame chunks by |parser_|; each chunk's payload
// is appended to |next_message_|. When a chunk is marked final, the
// accumulated message is delivered to |listener_| and the buffer is reset.
// NOTE(review): a decode failure trips the CHECK below, so malformed input
// from the peer aborts the process -- confirm this is intended. The line
// between the |buffer| declaration and the append is omitted from this
// excerpt (presumably a guard for chunks without data).
235 void WebSocket::OnReadDuringOpen(const char* data
, int len
) {
236 ScopedVector
<net::WebSocketFrameChunk
> frame_chunks
;
237 CHECK(parser_
.Decode(data
, len
, &frame_chunks
));
238 for (size_t i
= 0; i
< frame_chunks
.size(); ++i
) {
239 scoped_refptr
<net::IOBufferWithSize
> buffer
= frame_chunks
[i
]->data
;
241 next_message_
+= std::string(buffer
->data(), buffer
->size());
242 if (frame_chunks
[i
]->final_chunk
) {
243 listener_
->OnMessageReceived(next_message_
);
244 next_message_
.clear();
// Takes a local copy of |connect_callback_| and resets the member, then CHECKs
// that the copy is non-null.
// NOTE(review): the line that actually runs |temp| with |code| is omitted from
// this excerpt; presumably it follows the CHECK -- confirm.
249 void WebSocket::InvokeConnectCallback(int code
) {
250 net::CompletionCallback temp
= connect_callback_
;
251 connect_callback_
.Reset();
252 CHECK(!temp
.is_null());
// Disconnects |socket_| and reports the closure: a still-pending connect
// callback is invoked with |code|, and |listener_| is notified via OnClose().
// NOTE(review): the line between the two notifications is omitted from this
// excerpt (possibly an |else|), so it is unclear whether OnClose() also runs
// when the connect callback was pending. Any |state_| update is likewise not
// visible.
256 void WebSocket::Close(int code
) {
257 socket_
->Disconnect();
258 if (!connect_callback_
.is_null())
259 InvokeConnectCallback(code
);
261 listener_
->OnClose();