1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "mojo/services/network/web_socket_impl.h"
7 #include "base/logging.h"
8 #include "base/message_loop/message_loop.h"
9 #include "mojo/common/handle_watcher.h"
10 #include "mojo/services/network/network_context.h"
11 #include "mojo/services/network/public/cpp/web_socket_read_queue.h"
12 #include "mojo/services/network/public/cpp/web_socket_write_queue.h"
13 #include "net/websockets/websocket_channel.h"
14 #include "net/websockets/websocket_errors.h"
15 #include "net/websockets/websocket_event_interface.h"
16 #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode
17 #include "net/websockets/websocket_handshake_request_info.h"
18 #include "net/websockets/websocket_handshake_response_info.h"
19 #include "url/origin.h"
// Converts a mojo WebSocket::MessageType into the matching
// net::WebSocketFrameHeader::OpCode. Only the continuation, text and binary
// message types are expected. The conversion itself is a plain static_cast,
// which is valid only because the COMPILE_ASSERTs below pin the numeric values
// of the two enums to each other.
24 struct TypeConverter
<net::WebSocketFrameHeader::OpCode
,
25 WebSocket::MessageType
> {
26 static net::WebSocketFrameHeader::OpCode
Convert(
27 WebSocket::MessageType type
) {
// NOTE(review): this DCHECK is the only range validation. In builds where
// DCHECK is disabled, any other value is passed through the cast unchecked.
28 DCHECK(type
== WebSocket::MESSAGE_TYPE_CONTINUATION
||
29 type
== WebSocket::MESSAGE_TYPE_TEXT
||
30 type
== WebSocket::MESSAGE_TYPE_BINARY
);
31 typedef net::WebSocketFrameHeader::OpCode OpCode
;
32 // These compile asserts verify that the same underlying values are used for
33 // both types, so we can simply cast between them.
34 COMPILE_ASSERT(static_cast<OpCode
>(WebSocket::MESSAGE_TYPE_CONTINUATION
) ==
35 net::WebSocketFrameHeader::kOpCodeContinuation
,
36 enum_values_must_match_for_opcode_continuation
);
37 COMPILE_ASSERT(static_cast<OpCode
>(WebSocket::MESSAGE_TYPE_TEXT
) ==
38 net::WebSocketFrameHeader::kOpCodeText
,
39 enum_values_must_match_for_opcode_text
);
40 COMPILE_ASSERT(static_cast<OpCode
>(WebSocket::MESSAGE_TYPE_BINARY
) ==
41 net::WebSocketFrameHeader::kOpCodeBinary
,
42 enum_values_must_match_for_opcode_binary
);
// Value-preserving cast, justified by the asserts above.
43 return static_cast<OpCode
>(type
);
// Converts a net::WebSocketFrameHeader::OpCode back into a mojo
// WebSocket::MessageType by static_cast. Only the continuation, text and
// binary opcodes are expected. The numeric equality of the two enums is
// asserted by the COMPILE_ASSERTs in the MessageType -> OpCode converter, not
// here.
48 struct TypeConverter
<WebSocket::MessageType
,
49 net::WebSocketFrameHeader::OpCode
> {
50 static WebSocket::MessageType
Convert(
51 net::WebSocketFrameHeader::OpCode type
) {
// NOTE(review): debug-only check. If DCHECK is disabled, any other opcode
// value is cast to a MessageType unchanged.
52 DCHECK(type
== net::WebSocketFrameHeader::kOpCodeContinuation
||
53 type
== net::WebSocketFrameHeader::kOpCodeText
||
54 type
== net::WebSocketFrameHeader::kOpCodeBinary
);
55 return static_cast<WebSocket::MessageType
>(type
);
// Shorthand for net::WebSocketEventInterface::ChannelState, the value every
// event callback in this file returns (CHANNEL_ALIVE or CHANNEL_DELETED).
61 typedef net::WebSocketEventInterface::ChannelState ChannelState
;
// Adapts net::WebSocketChannel events (net::WebSocketEventInterface) into calls
// on the mojo WebSocketClient that WebSocketImpl::Connect() hands over.
// Incoming data is written into a data pipe (|receive_stream_|) through
// |write_queue_|. The client is notified from DidWriteToReceiveStream(), which
// is the completion callback for those writes.
63 struct WebSocketEventHandler
: public net::WebSocketEventInterface
{
// Takes ownership of |client| (moved in with Pass()).
// NOTE(review): single-argument constructor is not `explicit`; consider
// marking it explicit to avoid implicit conversions from WebSocketClientPtr.
65 WebSocketEventHandler(WebSocketClientPtr client
)
66 : client_(client
.Pass()) {
68 ~WebSocketEventHandler() override
{}
71 // net::WebSocketEventInterface methods:
72 ChannelState
OnAddChannelResponse(const std::string
& selected_subprotocol
,
73 const std::string
& extensions
) override
;
74 ChannelState
OnDataFrame(bool fin
,
75 WebSocketMessageType type
,
76 const std::vector
<char>& data
) override
;
77 ChannelState
OnClosingHandshake() override
;
78 ChannelState
OnFlowControl(int64 quota
) override
;
79 ChannelState
OnDropChannel(bool was_clean
,
81 const std::string
& reason
) override
;
82 ChannelState
OnFailChannel(const std::string
& message
) override
;
83 ChannelState
OnStartOpeningHandshake(
84 scoped_ptr
<net::WebSocketHandshakeRequestInfo
> request
) override
;
85 ChannelState
OnFinishOpeningHandshake(
86 scoped_ptr
<net::WebSocketHandshakeResponseInfo
> response
) override
;
87 ChannelState
OnSSLCertificateError(
88 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
,
90 const net::SSLInfo
& ssl_info
,
93 // Called once we've written to |receive_stream_|.
94 void DidWriteToReceiveStream(bool fin
,
95 net::WebSocketFrameHeader::OpCode type
,
// The mojo client that receives every event notification.
98 WebSocketClientPtr client_
;
// Producer end of the data pipe created in OnAddChannelResponse(); incoming
// data is written here through |write_queue_|.
99 ScopedDataPipeProducerHandle receive_stream_
;
// Write queue over |receive_stream_|, also created in OnAddChannelResponse().
100 scoped_ptr
<WebSocketWriteQueue
> write_queue_
;
102 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler
);
// Sets up the receive path. The producer end of |data_pipe| becomes
// |receive_stream_|, wrapped by a new |write_queue_|. The selected protocol,
// the extensions and the pipe's consumer end are then passed on in one call
// (presumably to |client_|; confirm). Always returns CHANNEL_ALIVE.
// NOTE(review): the first parameter is named |selected_protocol| here but
// |selected_subprotocol| in the class declaration; consider matching them.
105 ChannelState
WebSocketEventHandler::OnAddChannelResponse(
106 const std::string
& selected_protocol
,
107 const std::string
& extensions
) {
109 receive_stream_
= data_pipe
.producer_handle
.Pass();
110 write_queue_
.reset(new WebSocketWriteQueue(receive_stream_
.get()));
112 selected_protocol
, extensions
, data_pipe
.consumer_handle
.Pass());
113 return WebSocketEventInterface::CHANNEL_ALIVE
;
// Handles an incoming data frame. The payload is queued for writing into
// |receive_stream_|, with DidWriteToReceiveStream() bound as the completion
// callback, so the client is notified once the bytes have been written.
// Always returns CHANNEL_ALIVE.
// NOTE(review): base::Unretained(this) assumes no write callback runs after
// this handler is destroyed. That is presumably safe because |write_queue_| is
// owned by this object; confirm the queue drops pending callbacks on
// destruction.
116 ChannelState
WebSocketEventHandler::OnDataFrame(
118 net::WebSocketFrameHeader::OpCode type
,
119 const std::vector
<char>& data
) {
// Payload length narrowed to the uint32_t the write path uses.
// NOTE(review): zero-length frames are legal in WebSocket, so |data| may be
// empty here. Make sure the write does not take &data[0] in that case.
120 uint32_t size
= static_cast<uint32_t>(data
.size());
123 base::Bind(&WebSocketEventHandler::DidWriteToReceiveStream
,
124 base::Unretained(this),
126 return WebSocketEventInterface::CHANNEL_ALIVE
;
// Closing-handshake notification: nothing is forwarded to the client, and the
// channel is kept alive (CHANNEL_ALIVE).
129 ChannelState
WebSocketEventHandler::OnClosingHandshake() {
130 return WebSocketEventInterface::CHANNEL_ALIVE
;
// Forwards |quota| unchanged to client_->DidReceiveFlowControl() and keeps the
// channel alive.
133 ChannelState
WebSocketEventHandler::OnFlowControl(int64 quota
) {
134 client_
->DidReceiveFlowControl(quota
);
135 return WebSocketEventInterface::CHANNEL_ALIVE
;
// Reports the dropped channel to the client via
// DidClose(was_clean, code, reason) and returns CHANNEL_DELETED.
138 ChannelState
WebSocketEventHandler::OnDropChannel(bool was_clean
,
140 const std::string
& reason
) {
141 client_
->DidClose(was_clean
, code
, reason
);
142 return WebSocketEventInterface::CHANNEL_DELETED
;
// Reports the failure |message| to the client via DidFail() and returns
// CHANNEL_DELETED.
145 ChannelState
WebSocketEventHandler::OnFailChannel(const std::string
& message
) {
146 client_
->DidFail(message
);
147 return WebSocketEventInterface::CHANNEL_DELETED
;
// Opening-handshake request details are not forwarded to the client.
// |request| is released when the scoped_ptr goes out of scope. Keeps the
// channel alive.
150 ChannelState
WebSocketEventHandler::OnStartOpeningHandshake(
151 scoped_ptr
<net::WebSocketHandshakeRequestInfo
> request
) {
152 return WebSocketEventInterface::CHANNEL_ALIVE
;
// Opening-handshake response details are not forwarded to the client.
// |response| is released when the scoped_ptr goes out of scope. Keeps the
// channel alive.
155 ChannelState
WebSocketEventHandler::OnFinishOpeningHandshake(
156 scoped_ptr
<net::WebSocketHandshakeResponseInfo
> response
) {
157 return WebSocketEventInterface::CHANNEL_ALIVE
;
// Treats every certificate error as terminal. The client gets
// DidFail("SSL Error") and CHANNEL_DELETED is returned. |callbacks| is not
// invoked in this body, so the error is never overridden, and |ssl_info|
// details are not passed to the client.
160 ChannelState
WebSocketEventHandler::OnSSLCertificateError(
161 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
,
163 const net::SSLInfo
& ssl_info
,
165 client_
->DidFail("SSL Error");
166 return WebSocketEventInterface::CHANNEL_DELETED
;
// Completion callback for a write into |receive_stream_|. It tells the client
// how many bytes arrived via DidReceiveData(fin, type converted to
// WebSocket::MessageType, num_bytes). |buffer| itself is not sent; the client
// presumably reads the payload from the data pipe.
169 void WebSocketEventHandler::DidWriteToReceiveStream(
171 net::WebSocketFrameHeader::OpCode type
,
173 const char* buffer
) {
174 client_
->DidReceiveData(
175 fin
, ConvertTo
<WebSocket::MessageType
>(type
), num_bytes
);
// Binds this implementation to the incoming WebSocket |request|.
//   context:      supplies the URLRequestContext used by Connect().
//   app_refcount: ownership is taken and held for this object's lifetime
//                 (presumably to keep the app alive while the socket exists;
//                 confirm).
//   request:      the mojo interface request bound to |binding_|.
180 WebSocketImpl::WebSocketImpl(
181 NetworkContext
* context
,
182 scoped_ptr
<mojo::AppRefCount
> app_refcount
,
183 InterfaceRequest
<WebSocket
> request
)
184 : context_(context
), app_refcount_(app_refcount
.Pass()),
185 binding_(this, request
.Pass()) {
// Out-of-line destructor (presumably empty; owned members release themselves).
188 WebSocketImpl::~WebSocketImpl() {
// Starts a WebSocket connection to |url| by sending an add-channel request on
// a new net::WebSocketChannel.
//   protocols:   requested subprotocols, converted to std::vector<std::string>.
//   origin:      parsed into a url::Origin and sent with the request.
//   send_stream: consumer end of a data pipe from which outgoing payloads are
//                read (see Send()).
//   client:      receives channel events through a WebSocketEventHandler.
191 void WebSocketImpl::Connect(const String
& url
,
192 Array
<String
> protocols
,
193 const String
& origin
,
194 ScopedDataPipeConsumerHandle send_stream
,
195 WebSocketClientPtr client
) {
// Outgoing data path: keep the pipe and wrap it in a read queue.
197 send_stream_
= send_stream
.Pass();
198 read_queue_
.reset(new WebSocketReadQueue(send_stream_
.get()));
// Incoming events are routed to |client| through the event handler, which the
// channel takes ownership of.
199 scoped_ptr
<net::WebSocketEventInterface
> event_interface(
200 new WebSocketEventHandler(client
.Pass()));
201 channel_
.reset(new net::WebSocketChannel(event_interface
.Pass(),
202 context_
->url_request_context()));
203 channel_
->SendAddChannelRequest(GURL(url
.get()),
204 protocols
.To
<std::vector
<std::string
>>(),
205 url::Origin(GURL(origin
.get())));
// Sends one frame of |num_bytes| bytes. The payload is read from
// |send_stream_| through |read_queue_|. Once the bytes are available,
// DidReadFromSendStream() is called with |fin|, |type| and |num_bytes| bound.
// NOTE(review): |read_queue_| is created only in Connect(). Calling Send()
// before Connect() dereferences a null scoped_ptr unless guarded elsewhere.
// base::Unretained(this) relies on |read_queue_| (owned by this object) not
// running callbacks after destruction.
208 void WebSocketImpl::Send(bool fin
,
209 WebSocket::MessageType type
,
210 uint32_t num_bytes
) {
212 read_queue_
->Read(num_bytes
,
213 base::Bind(&WebSocketImpl::DidReadFromSendStream
,
214 base::Unretained(this),
215 fin
, type
, num_bytes
));
// Forwards |quota| to channel_->SendFlowControl().
// NOTE(review): |channel_| is created only in Connect(). Calling this first
// dereferences a null scoped_ptr unless guarded elsewhere.
218 void WebSocketImpl::FlowControl(int64_t quota
) {
220 channel_
->SendFlowControl(quota
);
// Starts the closing handshake on the channel with |code| and |reason|.
// NOTE(review): |channel_| is created only in Connect(). Calling this first
// dereferences a null scoped_ptr unless guarded elsewhere.
223 void WebSocketImpl::Close(uint16_t code
, const String
& reason
) {
225 channel_
->StartClosingHandshake(code
, reason
);
// Read-queue callback for Send(). It copies |num_bytes| bytes from |data| into
// a std::vector<char> and hands that buffer on, together with |fin| and |type|
// converted to a net::WebSocketFrameHeader::OpCode.
228 void WebSocketImpl::DidReadFromSendStream(bool fin
,
229 WebSocket::MessageType type
,
232 std::vector
<char> buffer(num_bytes
);
// NOTE(review): BUG. When |num_bytes| is 0 (an empty message), |buffer| is
// empty and &buffer[0] indexes an empty vector. That is undefined behavior
// even though memcpy copies nothing. Fix by building the vector from the range
// [data, data + num_bytes) instead of memcpy, or by guarding with
// `if (num_bytes)`.
233 memcpy(&buffer
[0], data
, num_bytes
);
236 fin
, ConvertTo
<net::WebSocketFrameHeader::OpCode
>(type
), buffer
);