1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/test/chromedriver/net/websocket.h"
7 #include "base/base64.h"
9 #include "base/bind_helpers.h"
10 #include "base/memory/scoped_vector.h"
11 #include "base/rand_util.h"
12 #include "base/sha1.h"
13 #include "base/stringprintf.h"
14 #include "base/strings/string_number_conversions.h"
15 #include "net/base/address_list.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/ip_endpoint.h"
18 #include "net/base/net_errors.h"
19 #include "net/base/net_util.h"
20 #include "net/http/http_response_headers.h"
21 #include "net/http/http_util.h"
22 #include "net/websockets/websocket_frame.h"
// Creates a WebSocket client for |url| that reports events to |listener|.
// Only the TCP socket object is created here; the connection itself is
// started by Connect().
//
// NOTE(review): original lines 25-27 (the start of the member initializer
// list), line 32 (the declaration of |port|) and the closing brace are not
// visible in this view.
24 WebSocket::WebSocket(const GURL
& url
, WebSocketListener
* listener
)
// An empty (0-byte) write buffer means BytesRemaining() starts at 0, so the
// first Write() immediately starts a socket write.
28 write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)),
// Read() asks the socket for at most read_buffer_->size() bytes per read.
29 read_buffer_(new net::IOBufferWithSize(4096)) {
// The host must be an IP literal: no host-name resolution happens here, and
// the CHECK is fatal if HostNoBrackets() does not parse as an IP address.
30 net::IPAddressNumber address
;
31 CHECK(net::ParseIPLiteralToNumber(url_
.HostNoBrackets(), &address
));
// The return value of StringToInt is ignored. NOTE(review): |port|'s initial
// value is on a line not visible here -- presumably a default used when the
// URL has no explicit port; confirm.
33 base::StringToInt(url_
.port(), &port
);
34 net::AddressList
addresses(net::IPEndPoint(address
, port
));
// A default-constructed NetLog::Source is passed, with NULL as the second
// argument (presumably the NetLog, i.e. no logging -- confirm).
35 net::NetLog::Source source
;
36 socket_
.reset(new net::TCPClientSocket(addresses
, NULL
, source
));
// Destroys the WebSocket. CHECK-fails (fatal) unless called on the thread
// that |thread_checker_| considers valid.
// NOTE(review): the closing brace (original line 41) is not visible here.
39 WebSocket::~WebSocket() {
40 CHECK(thread_checker_
.CalledOnValidThread());
// Starts the TCP connection to the address computed in the constructor.
// |callback| is stored in |connect_callback_| and is later reported through
// InvokeConnectCallback(), either on handshake success or from Close().
// CHECK-fails unless called on the valid thread while |state_| is INITIALIZED.
// NOTE(review): original line 46 (between the state check and storing the
// callback) is not visible; presumably a state transition -- confirm.
43 void WebSocket::Connect(const net::CompletionCallback
& callback
) {
44 CHECK(thread_checker_
.CalledOnValidThread());
45 CHECK_EQ(INITIALIZED
, state_
);
47 connect_callback_
= callback
;
// base::Unretained: the bound callback holds no reference to |this|, so this
// object must outlive the pending connect.
48 int code
= socket_
->Connect(base::Bind(
49 &WebSocket::OnSocketConnect
, base::Unretained(this)));
// Only ERR_IO_PENDING defers to the bound callback; any other result is
// handled here directly.
50 if (code
!= net::ERR_IO_PENDING
)
51 OnSocketConnect(code
);
// Sends |message| as one text frame (kOpCodeText). The payload is masked
// with a freshly generated key, and header + payload are queued via Write(),
// so the bytes may go out asynchronously.
// NOTE(review): original lines 56-61 (between the thread check and setting
// the payload length, presumably a state check and the header's flags) and
// the return statement(s) are not visible; what the bool result means cannot
// be confirmed from this view.
54 bool WebSocket::Send(const std::string
& message
) {
55 CHECK(thread_checker_
.CalledOnValidThread());
59 net::WebSocketFrameHeader
header(net::WebSocketFrameHeader::kOpCodeText
);
62 header
.payload_length
= message
.length();
63 int header_size
= net::GetWebSocketFrameHeaderSize(header
);
// A new masking key is generated for every frame sent.
64 net::WebSocketMaskingKey masking_key
= net::GenerateWebSocketMaskingKey();
65 std::string header_str
;
66 header_str
.resize(header_size
);
// Serialize the header into exactly |header_size| bytes; any mismatch with
// the size computed above is fatal.
67 CHECK_EQ(header_size
, net::WriteWebSocketFrameHeader(
68 header
, &masking_key
, &header_str
[0], header_str
.length()));
// Mask a copy, because |message| is a const reference owned by the caller.
70 std::string masked_message
= message
;
71 net::MaskWebSocketFramePayload(
72 masking_key
, 0, &masked_message
[0], masked_message
.length());
73 Write(header_str
+ masked_message
);
// Completion handler for the TCP connect started in Connect(). If |code| is
// net::OK, it generates a random Sec-WebSocket-Key and formats the HTTP
// upgrade request.
// NOTE(review): not visible in this view: the failure-branch body (original
// lines 79-82), the request line / Host header (85-86), and lines 93-98,
// including the format arguments and what is done with |handshake|.
77 void WebSocket::OnSocketConnect(int code
) {
78 if (code
!= net::OK
) {
// 16 random bytes, base64-encoded. The result is kept in |sec_key_| because
// OnReadDuringHandshake() uses it to compute the expected
// Sec-WebSocket-Accept value.
83 CHECK(base::Base64Encode(base::RandBytesAsString(16), &sec_key_
));
84 std::string handshake
= base::StringPrintf(
87 "Upgrade: websocket\r\n"
88 "Connection: Upgrade\r\n"
89 "Sec-WebSocket-Key: %s\r\n"
90 "Sec-WebSocket-Version: 13\r\n"
91 "Pragma: no-cache\r\n"
92 "Cache-Control: no-cache\r\n"
// Appends |data| to the outgoing queue |pending_write_|.
// - If the current write buffer has no bytes remaining, writing is started
//   immediately via ContinueWritingIfNecessary().
// - Otherwise the queued data is picked up once the in-flight buffer drains
//   (OnWrite() calls ContinueWritingIfNecessary()).
// NOTE(review): the closing brace (original line 105) is not visible here.
101 void WebSocket::Write(const std::string
& data
) {
102 pending_write_
+= data
;
103 if (!write_buffer_
->BytesRemaining())
104 ContinueWritingIfNecessary();
// Completion handler for a socket write issued by
// ContinueWritingIfNecessary(). |code| is the socket's write result; on the
// visible success path it is passed to DidConsume() as the byte count.
// If the socket is no longer connected, the WebSocket is closed with
// ERR_FAILED.
// NOTE(review): original lines 111-117 are not visible; presumably they
// return after Close() and handle negative |code| values before DidConsume()
// -- confirm.
107 void WebSocket::OnWrite(int code
) {
108 if (!socket_
->IsConnected()) {
109 // Supposedly if |StreamSocket| is closed, the error code may be undefined.
110 Close(net::ERR_FAILED
);
// Advance past the bytes just written, then send what remains or what is
// queued.
118 write_buffer_
->DidConsume(code
);
119 ContinueWritingIfNecessary();
// Issues the next socket write. When the current write buffer is fully
// drained, queued |pending_write_| data is moved into a fresh buffer first.
// The socket write covers write_buffer_->BytesRemaining() bytes and
// completes in OnWrite().
// NOTE(review): not visible in this view: original line 125 (the statement
// guarded by the empty-queue check, presumably an early return), line 130
// (presumably the closing brace of the outer if, so a partially sent buffer
// is continued), line 132 (the buffer argument to Write()), and lines 136-137
// (handling of a synchronous result, presumably OnWrite(code)).
122 void WebSocket::ContinueWritingIfNecessary() {
123 if (!write_buffer_
->BytesRemaining()) {
124 if (pending_write_
.empty())
// |pending_write_| is cleared right after this, so the code relies on
// StringIOBuffer holding its own copy of the queued bytes.
126 write_buffer_
= new net::DrainableIOBuffer(
127 new net::StringIOBuffer(pending_write_
),
128 pending_write_
.length());
129 pending_write_
.clear();
// base::Unretained: |this| must outlive the pending write.
131 int code
= socket_
->Write(
133 write_buffer_
->BytesRemaining(),
134 base::Bind(&WebSocket::OnWrite
, base::Unretained(this)));
135 if (code
!= net::ERR_IO_PENDING
)
// Requests up to read_buffer_->size() bytes from the socket, with OnRead()
// as the completion callback (bound with base::Unretained, so |this| must
// outlive the read).
// NOTE(review): original line 141 (the buffer argument) and lines 145-147
// (handling of a synchronous result, presumably OnRead(code)) are not
// visible.
139 void WebSocket::Read() {
140 int code
= socket_
->Read(
142 read_buffer_
->size(),
143 base::Bind(&WebSocket::OnRead
, base::Unretained(this)));
144 if (code
!= net::ERR_IO_PENDING
)
// Handles the result of a socket read. Received bytes in |read_buffer_| are
// routed by connection state:
// - CONNECTING: handshake parsing (OnReadDuringHandshake()).
// - OPEN: frame decoding (OnReadDuringOpen()).
// NOTE(review): not visible in this view: original line 149 (the condition
// guarding Close(), presumably a non-positive |code|), lines 151-153, and
// line 160 (the statement guarded by the final state check, presumably the
// next Read()).
148 void WebSocket::OnRead(int code
) {
// A zero |code| carries no error of its own, so it is reported as
// ERR_FAILED.
150 Close(code
? code
: net::ERR_FAILED
);
154 if (state_
== CONNECTING
)
155 OnReadDuringHandshake(read_buffer_
->data(), code
);
156 else if (state_
== OPEN
)
157 OnReadDuringOpen(read_buffer_
->data(), code
);
// The handlers above may have closed the connection; only act further if
// they did not.
159 if (state_
!= CLOSED
)
// Accumulates the server's handshake response in |handshake_response_|.
// Once the end of the header block is located, it validates the response:
// - status code 101;
// - "Upgrade: WebSocket" and "Connection: Upgrade" headers;
// - a Sec-WebSocket-Accept header matching the value derived from
//   |sec_key_|.
// On failure it calls Close(net::ERR_FAILED). On success it reports net::OK
// through InvokeConnectCallback() and feeds any bytes that followed the
// headers to OnReadDuringOpen().
// NOTE(review): not visible in this view: original lines 168-169 (presumably
// an early return while headers are incomplete), line 173 (the rest of the
// Base64Encode call), lines 183-184 (presumably a return after Close()), and
// lines 187-188 (presumably the transition to OPEN).
163 void WebSocket::OnReadDuringHandshake(const char* data
, int len
) {
164 handshake_response_
+= std::string(data
, len
);
// A result of -1 is handled specially below; presumably it means the full
// header block has not arrived yet.
165 int headers_end
= net::HttpUtil::LocateEndOfHeaders(
166 handshake_response_
.data(), handshake_response_
.size(), 0);
167 if (headers_end
== -1)
// kMagicKey is the GUID from RFC 6455. The expected Sec-WebSocket-Accept is
// base64(SHA-1(sec_key_ + kMagicKey)).
170 const char kMagicKey
[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
171 std::string websocket_accept
;
172 CHECK(base::Base64Encode(base::SHA1HashString(sec_key_
+ kMagicKey
),
174 scoped_refptr
<net::HttpResponseHeaders
> headers(
175 new net::HttpResponseHeaders(
176 net::HttpUtil::AssembleRawHeaders(
177 handshake_response_
.data(), headers_end
)));
178 if (headers
->response_code() != 101 ||
179 !headers
->HasHeaderValue("Upgrade", "WebSocket") ||
180 !headers
->HasHeaderValue("Connection", "Upgrade") ||
181 !headers
->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept
)) {
182 Close(net::ERR_FAILED
);
// Bytes after |headers_end| arrived in the same reads as the handshake. They
// are kept and processed as frame data once the connect callback has run.
185 std::string leftover_message
= handshake_response_
.substr(headers_end
);
186 handshake_response_
.clear();
189 InvokeConnectCallback(net::OK
);
190 if (!leftover_message
.empty())
191 OnReadDuringOpen(leftover_message
.c_str(), leftover_message
.length());
// Decodes WebSocket frame chunks from |data| with |parser_| and reassembles
// messages. Each chunk's payload is appended to |next_message_|. When a
// chunk has final_chunk set, the accumulated text goes to
// listener_->OnMessageReceived() and the buffer is cleared.
// A decode failure is fatal (CHECK) rather than a graceful close.
// NOTE(review): no opcode handling is visible, so control frames would
// presumably be treated like data. final_chunk presumably marks the end of a
// frame rather than of a fragmented message -- confirm. Original line 199
// (presumably a null check on |buffer|) and the closing braces are not
// visible.
194 void WebSocket::OnReadDuringOpen(const char* data
, int len
) {
195 ScopedVector
<net::WebSocketFrameChunk
> frame_chunks
;
196 CHECK(parser_
.Decode(data
, len
, &frame_chunks
));
197 for (size_t i
= 0; i
< frame_chunks
.size(); ++i
) {
198 scoped_refptr
<net::IOBufferWithSize
> buffer
= frame_chunks
[i
]->data
;
200 next_message_
+= std::string(buffer
->data(), buffer
->size());
201 if (frame_chunks
[i
]->final_chunk
) {
202 listener_
->OnMessageReceived(next_message_
);
203 next_message_
.clear();
// Takes the pending connect callback so it can be reported with |code|.
// |connect_callback_| is cleared before use, so a later Close() (which only
// reports when connect_callback_ is non-null) does not report a second time.
// CHECK-fails if no connect callback was pending.
// NOTE(review): original lines 212-213 (presumably temp.Run(code) and the
// closing brace) are not visible.
208 void WebSocket::InvokeConnectCallback(int code
) {
209 net::CompletionCallback temp
= connect_callback_
;
210 connect_callback_
.Reset();
211 CHECK(!temp
.is_null());
215 void WebSocket::Close(int code
) {
216 socket_
->Disconnect();
217 if (!connect_callback_
.is_null())
218 InvokeConnectCallback(code
);
220 listener_
->OnClose();