1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
10 #include "net/base/completion_callback.h"
11 #include "net/base/net_errors.h"
12 #include "net/http/http_request_headers.h"
13 #include "net/http/http_request_info.h"
14 #include "net/http/http_response_headers.h"
15 #include "net/http/http_response_info.h"
16 #include "net/socket/client_socket_handle.h"
17 #include "net/socket/socket_test_util.h"
18 #include "net/websockets/websocket_basic_handshake_stream.h"
19 #include "net/websockets/websocket_stream.h"
20 #include "net/websockets/websocket_test_util.h"
21 #include "testing/gtest/include/gtest/gtest.h"
23 #include "url/origin.h"
28 // This class encapsulates the details of creating a mock ClientSocketHandle.
// It holds a WebSocketMockClientSocketFactoryMaker and a
// MockTransportClientSocketPool; the pool is built from the maker's factory().
29 class MockClientSocketHandleFactory
{
// Constructs |pool_| with the arguments (1, 1, socket_factory_maker_.factory()).
// NOTE(review): the meaning of the two leading 1s is not visible in this file;
// presumably socket limits of the pool -- confirm against
// MockTransportClientSocketPool.
31 MockClientSocketHandleFactory()
32 : pool_(1, 1, socket_factory_maker_
.factory()) {}
34 // The created socket expects |expect_written| to be written to the socket,
35 // and will respond with |return_to_read|. The test will fail if the expected
36 // text is not written, or if all the bytes are not read.
// Returns the newly allocated ClientSocketHandle; ownership is handed to the
// caller through Pass().
37 scoped_ptr
<ClientSocketHandle
> CreateClientSocketHandle(
38 const std::string
& expect_written
,
39 const std::string
& return_to_read
) {
// Register the expected traffic with the factory maker before the handle is
// created.
40 socket_factory_maker_
.SetExpectations(expect_written
, return_to_read
);
41 scoped_ptr
<ClientSocketHandle
> socket_handle(new ClientSocketHandle
);
// NOTE(review): the fragment below reads like a single argument of a call
// whose opening and remaining arguments are not present in this view;
// presumably |socket_handle| is initialised against |pool_| here -- confirm
// against the upstream file.
44 scoped_refptr
<MockTransportSocketParams
>(),
49 return socket_handle
.Pass();
// Declared before |pool_| so that it is constructed first: the initialiser of
// |pool_| calls socket_factory_maker_.factory().
53 WebSocketMockClientSocketFactoryMaker socket_factory_maker_
;
54 MockTransportClientSocketPool pool_
;
56 DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory
);
// A WebSocketStream::ConnectDelegate in which every override has an empty
// body. The test fixture in this file passes the address of an instance to
// WebSocketHandshakeStreamCreateHelper.
59 class TestConnectDelegate
: public WebSocketStream::ConnectDelegate
{
61 ~TestConnectDelegate() override
{}
// Empty body: |stream| is simply destroyed when the call returns.
63 void OnSuccess(scoped_ptr
<WebSocketStream
> stream
) override
{}
// Empty body: |failure_message| is not recorded.
64 void OnFailure(const std::string
& failure_message
) override
{}
// Empty body: |request| is not inspected.
65 void OnStartOpeningHandshake(
66 scoped_ptr
<WebSocketHandshakeRequestInfo
> request
) override
{}
// Empty body: |response| is not inspected.
67 void OnFinishOpeningHandshake(
68 scoped_ptr
<WebSocketHandshakeResponseInfo
> response
) override
{}
// Empty body: none of the arguments is used.
// NOTE(review): the name of the first parameter (the SSLErrorCallbacks
// pointer) is not present in this view -- confirm against the upstream file.
69 void OnSSLCertificateError(
70 scoped_ptr
<WebSocketEventInterface::SSLErrorCallbacks
>
72 const SSLInfo
& ssl_info
,
73 bool fatal
) override
{}
// Test fixture. CreateAndInitializeStream() drives a complete opening
// handshake over a mock socket and returns the result of Upgrade().
76 class WebSocketHandshakeStreamCreateHelperTest
: public ::testing::Test
{
// Runs the handshake and returns the upgraded stream.
//   sub_protocols          - sub-protocols for the helper.
//                            NOTE(review): no use of this parameter is visible;
//                            the create_helper constructor call below is cut
//                            off after its first argument, so it is presumably
//                            passed there -- confirm.
//   extra_request_headers  - forwarded to WebSocketStandardRequest() to build
//                            the text the mock socket expects to be written.
//   extra_response_headers - forwarded to WebSocketStandardResponse() to build
//                            the text the mock socket replies with.
78 scoped_ptr
<WebSocketStream
> CreateAndInitializeStream(
79 const std::vector
<std::string
>& sub_protocols
,
80 const std::string
& extra_request_headers
,
81 const std::string
& extra_response_headers
) {
// Used both for the expected request (via url::Origin) and for the "Origin"
// header that is sent.
82 static const char kOrigin
[] = "http://localhost";
83 WebSocketHandshakeStreamCreateHelper
create_helper(&connect_delegate_
,
// The helper is given a pointer to the fixture's |failure_message_|.
85 create_helper
.set_failure_message(&failure_message_
);
// Mock socket: first argument is the text expected to be written, second is
// the text returned to reads (see CreateClientSocketHandle()).
87 scoped_ptr
<ClientSocketHandle
> socket_handle
=
88 socket_handle_factory_
.CreateClientSocketHandle(
89 WebSocketStandardRequest("/", "localhost",
90 url::Origin(GURL(kOrigin
)),
91 extra_request_headers
),
92 WebSocketStandardResponse(extra_response_headers
));
// NOTE(review): the meaning of the |false| argument is not visible in this
// file -- confirm against CreateBasicStream()'s declaration.
94 scoped_ptr
<WebSocketHandshakeStreamBase
> handshake(
95 create_helper
.CreateBasicStream(socket_handle
.Pass(), false));
97 // If in future the implementation type returned by CreateBasicStream()
98 // changes, this static_cast will be wrong. However, in that case the test
99 // will fail and AddressSanitizer should identify the issue.
// Pins the handshake key to a fixed value. NOTE(review): presumably the key
// that WebSocketStandardRequest()/WebSocketStandardResponse() are built
// around -- confirm in websocket_test_util.
100 static_cast<WebSocketBasicHandshakeStream
*>(handshake
.get())
101 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
// Describe the request: GET ws://localhost/ with LOAD_DISABLE_CACHE.
103 HttpRequestInfo request_info
;
104 request_info
.url
= GURL("ws://localhost/");
105 request_info
.method
= "GET";
106 request_info
.load_flags
= LOAD_DISABLE_CACHE
;
// NOTE(review): |rv| is assigned here and below, but no check of it is
// visible in this view -- confirm against the upstream file.
107 int rv
= handshake
->InitializeStream(
108 &request_info
, DEFAULT_PRIORITY
, BoundNetLog(), CompletionCallback());
// Request headers handed to SendRequest().
111 HttpRequestHeaders headers
;
112 headers
.SetHeader("Host", "localhost");
113 headers
.SetHeader("Connection", "Upgrade");
114 headers
.SetHeader("Pragma", "no-cache");
115 headers
.SetHeader("Cache-Control", "no-cache");
116 headers
.SetHeader("Upgrade", "websocket");
117 headers
.SetHeader("Origin", kOrigin
);
118 headers
.SetHeader("Sec-WebSocket-Version", "13");
119 headers
.SetHeader("User-Agent", "");
120 headers
.SetHeader("Accept-Encoding", "gzip, deflate");
121 headers
.SetHeader("Accept-Language", "en-us,fr");
123 HttpResponseInfo response
;
// NOTE(review): nothing visible here waits on |dummy|; presumably both calls
// below complete synchronously with the mock socket -- confirm.
124 TestCompletionCallback dummy
;
126 rv
= handshake
->SendRequest(headers
, &response
, dummy
.callback());
130 rv
= handshake
->ReadResponseHeaders(dummy
.callback());
// The response must be a 101 carrying "Connection: Upgrade" and
// "Upgrade: websocket".
132 EXPECT_EQ(101, response
.headers
->response_code());
133 EXPECT_TRUE(response
.headers
->HasHeaderValue("Connection", "Upgrade"));
134 EXPECT_TRUE(response
.headers
->HasHeaderValue("Upgrade", "websocket"));
135 return handshake
->Upgrade();
// Source of the mock socket handles used by CreateAndInitializeStream().
138 MockClientSocketHandleFactory socket_handle_factory_
;
// Delegate whose address is passed to the create helper.
139 TestConnectDelegate connect_delegate_
;
// Its address is passed to create_helper.set_failure_message().
140 std::string failure_message_
;
143 // Confirm that the basic case works as expected.
144 TEST_F(WebSocketHandshakeStreamCreateHelperTest
, BasicStream
) {
145 scoped_ptr
<WebSocketStream
> stream
=
146 CreateAndInitializeStream(std::vector
<std::string
>(), "", "");
147 EXPECT_EQ("", stream
->GetExtensions());
148 EXPECT_EQ("", stream
->GetSubProtocol());
151 // Verify that the sub-protocols are passed through.
152 TEST_F(WebSocketHandshakeStreamCreateHelperTest
, SubProtocols
) {
153 std::vector
<std::string
> sub_protocols
;
154 sub_protocols
.push_back("chat");
155 sub_protocols
.push_back("superchat");
156 scoped_ptr
<WebSocketStream
> stream
= CreateAndInitializeStream(
157 sub_protocols
, "Sec-WebSocket-Protocol: chat, superchat\r\n",
158 "Sec-WebSocket-Protocol: superchat\r\n");
159 EXPECT_EQ("superchat", stream
->GetSubProtocol());
162 // Verify that extension name is available. Bad extension names are tested in
163 // websocket_stream_test.cc.
164 TEST_F(WebSocketHandshakeStreamCreateHelperTest
, Extensions
) {
165 scoped_ptr
<WebSocketStream
> stream
= CreateAndInitializeStream(
166 std::vector
<std::string
>(), "",
167 "Sec-WebSocket-Extensions: permessage-deflate\r\n");
168 EXPECT_EQ("permessage-deflate", stream
->GetExtensions());
171 // Verify that extension parameters are available. Bad parameters are tested in
172 // websocket_stream_test.cc.
// The response offers permessage-deflate with client_max_window_bits=14,
// server_max_window_bits=14, server_no_context_takeover and
// client_no_context_takeover; no sub-protocols or extra request headers are
// supplied.
173 TEST_F(WebSocketHandshakeStreamCreateHelperTest
, ExtensionParameters
) {
174 scoped_ptr
<WebSocketStream
> stream
= CreateAndInitializeStream(
175 std::vector
<std::string
>(), "",
176 "Sec-WebSocket-Extensions: permessage-deflate;"
177 " client_max_window_bits=14; server_max_window_bits=14;"
178 " server_no_context_takeover; client_no_context_takeover\r\n");
// The literal below is the header value sent above, without the header name
// and without the trailing CRLF.
// NOTE(review): the opening of this assertion is not present in this view;
// presumably EXPECT_EQ(<literal>, stream->GetExtensions()) as in the
// Extensions test -- confirm against the upstream file.
181 "permessage-deflate;"
182 " client_max_window_bits=14; server_max_window_bits=14;"
183 " server_no_context_takeover; client_no_context_takeover",
184 stream
->GetExtensions());