1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h"
8 #include "base/bind_helpers.h"
9 #include "base/callback_helpers.h"
10 #include "base/logging.h"
11 #include "crypto/secure_util.h"
12 #include "net/base/host_port_pair.h"
13 #include "net/base/io_buffer.h"
14 #include "net/base/net_errors.h"
15 #include "net/cert/cert_status_flags.h"
16 #include "net/cert/cert_verifier.h"
17 #include "net/cert/cert_verify_result.h"
18 #include "net/cert/x509_certificate.h"
19 #include "net/http/transport_security_state.h"
20 #include "net/socket/client_socket_handle.h"
21 #include "net/socket/ssl_client_socket.h"
22 #include "net/socket/ssl_server_socket.h"
23 #include "net/ssl/ssl_config_service.h"
24 #include "remoting/base/rsa_key_pair.h"
25 #include "remoting/protocol/auth_util.h"
26 #include "remoting/protocol/p2p_stream_socket.h"
29 #include "net/socket/ssl_client_socket_openssl.h"
31 #include "net/socket/client_socket_factory.h"
39 // A CertVerifier which rejects every certificate.
// It is installed as the client-side verifier in SecureAndAuthenticate(),
// where the expected peer certificate is instead accepted through
// SSLConfig::allowed_bad_certs; presumably this verifier exists so that no
// other certificate can ever pass verification.
40 class FailingCertVerifier
: public net::CertVerifier
{
42 FailingCertVerifier() {}
43 ~FailingCertVerifier() override
{}
// Rejects |cert| synchronously: records it as |verified_cert|, marks it
// CERT_STATUS_INVALID and returns ERR_CERT_INVALID. The result is returned
// directly; the body does not use |callback| or |out_req|.
45 int Verify(net::X509Certificate
* cert
,
46 const std::string
& hostname
,
47 const std::string
& ocsp_response
,
50 net::CertVerifyResult
* verify_result
,
51 const net::CompletionCallback
& callback
,
52 scoped_ptr
<Request
>* out_req
,
53 const net::BoundNetLog
& net_log
) override
{
54 verify_result
->verified_cert
= cert
;
55 verify_result
->cert_status
= net::CERT_STATUS_INVALID
;
56 return net::ERR_CERT_INVALID
;
60 // Implements net::StreamSocket interface on top of P2PStreamSocket to be passed
61 // to net::SSLClientSocket and net::SSLServerSocket.
// Only Read() and Write() are forwarded to the wrapped socket; the other
// StreamSocket methods either return fixed values, fail with ERR_FAILED, or
// are marked NOTREACHED().
62 class NetStreamSocketAdapter
: public net::StreamSocket
{
// Takes ownership of |socket|.
64 NetStreamSocketAdapter(scoped_ptr
<P2PStreamSocket
> socket
)
65 : socket_(socket
.Pass()) {}
66 ~NetStreamSocketAdapter() override
{}
// Read() and Write() delegate directly to the wrapped P2PStreamSocket.
68 int Read(net::IOBuffer
* buf
, int buf_len
,
69 const net::CompletionCallback
& callback
) override
{
70 return socket_
->Read(buf
, buf_len
, callback
);
72 int Write(net::IOBuffer
* buf
, int buf_len
,
73 const net::CompletionCallback
& callback
) override
{
74 return socket_
->Write(buf
, buf_len
, callback
);
// Changing socket buffer sizes is not supported; both setters fail.
77 int SetReceiveBufferSize(int32_t size
) override
{
79 return net::ERR_FAILED
;
82 int SetSendBufferSize(int32_t size
) override
{
84 return net::ERR_FAILED
;
// Connect() always fails. The wrapped P2P socket is presumably already
// connected when handed to the adapter (IsConnected() always reports true).
87 int Connect(const net::CompletionCallback
& callback
) override
{
89 return net::ERR_FAILED
;
// Disconnect() destroys the wrapped socket.
// NOTE(review): Read()/Write() dereference |socket_| unconditionally, so
// calling them after Disconnect() would dereference a null pointer.
91 void Disconnect() override
{ socket_
.reset(); }
// Connection state is not tracked; both queries always report true.
92 bool IsConnected() const override
{ return true; }
93 bool IsConnectedAndIdle() const override
{ return true; }
94 int GetPeerAddress(net::IPEndPoint
* address
) const override
{
95 // SSL sockets call this function so it must return some result.
// Reports a placeholder IPv4 endpoint (an address of kIPv4AddressSize
// bytes, port 0) rather than a real peer address.
96 net::IPAddressNumber
ip_address(net::kIPv4AddressSize
);
97 *address
= net::IPEndPoint(ip_address
, 0);
100 int GetLocalAddress(net::IPEndPoint
* address
) const override
{
102 return net::ERR_FAILED
;
// Returns the adapter's own BoundNetLog member, which the constructor does
// not initialize explicitly.
104 const net::BoundNetLog
& NetLog() const override
{ return net_log_
; }
// Methods marked NOTREACHED() are not expected to be called on this adapter.
105 void SetSubresourceSpeculation() override
{ NOTREACHED(); }
106 void SetOmniboxSpeculation() override
{ NOTREACHED(); }
107 bool WasEverUsed() const override
{
111 bool UsingTCPFastOpen() const override
{
115 void EnableTCPFastOpenIfSupported() override
{ NOTREACHED(); }
116 bool WasNpnNegotiated() const override
{
// No application protocol is negotiated at this layer.
120 net::NextProto
GetNegotiatedProtocol() const override
{
122 return net::kProtoUnknown
;
124 bool GetSSLInfo(net::SSLInfo
* ssl_info
) override
{
128 void GetConnectionAttempts(net::ConnectionAttempts
* out
) const override
{
131 void ClearConnectionAttempts() override
{ NOTREACHED(); }
132 void AddConnectionAttempts(const net::ConnectionAttempts
& attempts
) override
{
// The wrapped P2P socket; owned by the adapter and reset by Disconnect().
137 scoped_ptr
<P2PStreamSocket
> socket_
;
// Log object handed out by NetLog().
138 net::BoundNetLog net_log_
;
141 // Implements P2PStreamSocket interface on top of net::StreamSocket.
// CheckDone() uses it to hand the authenticated SSL socket back to the
// caller.
142 class P2PStreamSocketAdapter
: public P2PStreamSocket
{
// Takes ownership of |socket|.
144 P2PStreamSocketAdapter(scoped_ptr
<net::StreamSocket
> socket
)
145 : socket_(socket
.Pass()) {}
146 ~P2PStreamSocketAdapter() override
{}
// Read() and Write() forward to the wrapped net::StreamSocket, passing the
// raw IOBuffer pointer obtained from |buf|.
148 int Read(const scoped_refptr
<net::IOBuffer
>& buf
, int buf_len
,
149 const net::CompletionCallback
& callback
) override
{
150 return socket_
->Read(buf
.get(), buf_len
, callback
);
152 int Write(const scoped_refptr
<net::IOBuffer
>& buf
, int buf_len
,
153 const net::CompletionCallback
& callback
) override
{
154 return socket_
->Write(buf
.get(), buf_len
, callback
);
// The wrapped stream socket, owned by the adapter.
158 scoped_ptr
<net::StreamSocket
> socket_
;
// Creates an authenticator for the client side of the SSL handshake.
// |remote_cert| is the certificate the peer is expected to present; in
// SecureAndAuthenticate() it is added to SSLConfig::allowed_bad_certs as a
// DER certificate. |auth_key| is passed to GetAuthBytes() when computing the
// authentication digests.
164 scoped_ptr
<SslHmacChannelAuthenticator
>
165 SslHmacChannelAuthenticator::CreateForClient(
166 const std::string
& remote_cert
,
167 const std::string
& auth_key
) {
168 scoped_ptr
<SslHmacChannelAuthenticator
> result(
169 new SslHmacChannelAuthenticator(auth_key
));
170 result
->remote_cert_
= remote_cert
;
171 return result
.Pass();
// Creates an authenticator for the server (host) side of the SSL handshake.
// |local_cert| holds the certificate bytes that SecureAndAuthenticate()
// parses with X509Certificate::CreateFromBytes(). |key_pair| supplies the
// private key passed to CreateSSLServerSocket(). Storing |key_pair| in
// |local_key_pair_| makes is_ssl_server() return true.
// NOTE(review): a null |key_pair| would make this instance behave as a
// client.
174 scoped_ptr
<SslHmacChannelAuthenticator
>
175 SslHmacChannelAuthenticator::CreateForHost(
176 const std::string
& local_cert
,
177 scoped_refptr
<RsaKeyPair
> key_pair
,
178 const std::string
& auth_key
) {
179 scoped_ptr
<SslHmacChannelAuthenticator
> result(
180 new SslHmacChannelAuthenticator(auth_key
));
181 result
->local_cert_
= local_cert
;
182 result
->local_key_pair_
= key_pair
;
183 return result
.Pass();
// Instances are created by CreateForClient() and CreateForHost(), which then
// fill in the side-specific certificate/key fields. The constructor
// initializes |auth_key_|.
186 SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
187 const std::string
& auth_key
)
188 : auth_key_(auth_key
) {
191 SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
// Secures |socket| with SSL and then authenticates the channel.
// The authenticator acts as the SSL server when a local key pair was
// provided (see is_ssl_server()) and as the SSL client otherwise.
// OnConnected() runs when the handshake completes and starts the exchange of
// authentication digests. |done_callback| is stored and later run either:
// - by CheckDone(), with the secured socket, on success; or
// - by NotifyError(), with an error code and a null socket.
194 void SslHmacChannelAuthenticator::SecureAndAuthenticate(
195 scoped_ptr
<P2PStreamSocket
> socket
,
196 const DoneCallback
& done_callback
) {
197 DCHECK(CalledOnValidThread());
199 done_callback_
= done_callback
;
// Server side: wrap |socket| in an SSLServerSocket built from |local_cert_|
// and |local_key_pair_|, then start the server handshake.
202 if (is_ssl_server()) {
204 // Client plugin doesn't use server SSL sockets, and so SSLServerSocket
205 // implementation is not compiled for NaCl as part of net_nacl.
207 result
= net::ERR_FAILED
;
209 scoped_refptr
<net::X509Certificate
> cert
=
210 net::X509Certificate::CreateFromBytes(
211 local_cert_
.data(), local_cert_
.length());
213 LOG(ERROR
) << "Failed to parse X509Certificate";
214 NotifyError(net::ERR_FAILED
);
218 net::SSLConfig ssl_config
;
219 ssl_config
.require_ecdhe
= true;
221 scoped_ptr
<net::SSLServerSocket
> server_socket
= net::CreateSSLServerSocket(
222 make_scoped_ptr(new NetStreamSocketAdapter(socket
.Pass())), cert
.get(),
223 local_key_pair_
->private_key(), ssl_config
);
// Keep a raw pointer, because ownership moves into |socket_| before
// Handshake() is started.
224 net::SSLServerSocket
* raw_server_socket
= server_socket
.get();
225 socket_
= server_socket
.Pass();
226 result
= raw_server_socket
->Handshake(
227 base::Bind(&SslHmacChannelAuthenticator::OnConnected
,
228 base::Unretained(this)));
// Client side: the expected peer certificate is accepted through
// |allowed_bad_certs| below, and FailingCertVerifier rejects every other
// certificate.
231 transport_security_state_
.reset(new net::TransportSecurityState
);
232 cert_verifier_
.reset(new FailingCertVerifier
);
// Accept |remote_cert_| even though it is not issued by a trusted authority.
234 net::SSLConfig::CertAndStatus cert_and_status
;
235 cert_and_status
.cert_status
= net::CERT_STATUS_AUTHORITY_INVALID
;
236 cert_and_status
.der_cert
= remote_cert_
;
238 net::SSLConfig ssl_config
;
239 // Certificate verification and revocation checking are not needed
240 // because we use self-signed certs. Disable it so that the SSL
241 // layer doesn't try to initialize OCSP (OCSP works only on the IO
// thread).
243 ssl_config
.cert_io_enabled
= false;
244 ssl_config
.rev_checking_enabled
= false;
245 ssl_config
.allowed_bad_certs
.push_back(cert_and_status
);
246 ssl_config
.require_ecdhe
= true;
// The peer is addressed by the fixed name kSslFakeHostName with port 0.
248 net::HostPortPair
host_and_port(kSslFakeHostName
, 0);
249 net::SSLClientSocketContext context
;
250 context
.transport_security_state
= transport_security_state_
.get();
251 context
.cert_verifier
= cert_verifier_
.get();
252 scoped_ptr
<net::ClientSocketHandle
> socket_handle(
253 new net::ClientSocketHandle
);
254 socket_handle
->SetSocket(
255 make_scoped_ptr(new NetStreamSocketAdapter(socket
.Pass())));
258 // net_nacl doesn't include ClientSocketFactory.
259 socket_
.reset(new net::SSLClientSocketOpenSSL(
260 socket_handle
.Pass(), host_and_port
, ssl_config
, context
));
263 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
264 socket_handle
.Pass(), host_and_port
, ssl_config
, context
);
267 result
= socket_
->Connect(
268 base::Bind(&SslHmacChannelAuthenticator::OnConnected
,
269 base::Unretained(this)));
// ERR_IO_PENDING means the socket will report completion later through the
// bound OnConnected() callback.
272 if (result
== net::ERR_IO_PENDING
)
// Returns true when a local key pair is present. Only CreateForHost() sets
// one, and such instances act as the SSL server.
278 bool SslHmacChannelAuthenticator::is_ssl_server() {
279 return local_key_pair_
.get() != nullptr;
// Completion handler for the SSL Handshake() (server) or Connect() (client).
// On success it computes this side's authentication digest, prepares the
// write and read buffers, and starts the digest exchange.
282 void SslHmacChannelAuthenticator::OnConnected(int result
) {
283 if (result
!= net::OK
) {
284 LOG(WARNING
) << "Failed to establish SSL connection. Error: "
285 << net::ErrorToString(result
);
290 // Generate authentication digest to write to the socket.
// Each side derives its digest with its own exporter label (host or
// client). VerifyAuthBytes() expects the peer to have used the opposite
// label.
291 std::string auth_bytes
= GetAuthBytes(
292 socket_
.get(), is_ssl_server() ?
293 kHostAuthSslExporterLabel
: kClientAuthSslExporterLabel
, auth_key_
);
294 if (auth_bytes
.empty()) {
295 NotifyError(net::ERR_FAILED
);
299 // Allocate a buffer to write the digest.
300 auth_write_buf_
= new net::DrainableIOBuffer(
301 new net::StringIOBuffer(auth_bytes
), auth_bytes
.size());
303 // Read an incoming token.
// The read buffer's capacity is exactly one digest (kAuthDigestLength).
304 auth_read_buf_
= new net::GrowableIOBuffer();
305 auth_read_buf_
->SetCapacity(kAuthDigestLength
);
307 // If WriteAuthenticationBytes() results in |done_callback_| being
308 // called then we must not do anything else because this object may
309 // be destroyed at that point.
310 bool callback_called
= false;
311 WriteAuthenticationBytes(&callback_called
);
312 if (!callback_called
)
313 ReadAuthenticationBytes();
// Writes the unsent part of |auth_write_buf_| to |socket_|.
// Synchronous completions are processed by HandleAuthBytesWritten(). If
// Write() returns ERR_IO_PENDING, OnAuthBytesWritten() continues later.
// |callback_called| may be null (OnAuthBytesWritten() passes nullptr). When
// it is non-null, it tells the caller whether |done_callback_| has already
// been run.
316 void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
317 bool* callback_called
) {
319 int result
= socket_
->Write(
320 auth_write_buf_
.get(),
321 auth_write_buf_
->BytesRemaining(),
322 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten
,
323 base::Unretained(this)));
324 if (result
== net::ERR_IO_PENDING
)
326 if (!HandleAuthBytesWritten(result
, callback_called
))
// Completion callback for an asynchronous Write() of the digest.
// Writing continues when HandleAuthBytesWritten() returns true, which
// presumably means bytes remain to be sent. No |callback_called| flag is
// tracked on this path, so nullptr is passed.
331 void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result
) {
332 DCHECK(CalledOnValidThread());
334 if (HandleAuthBytesWritten(result
, nullptr))
335 WriteAuthenticationBytes(nullptr);
// Processes the result of one Write() of the authentication digest.
// - Write errors are logged.
// - Successful writes are consumed from |auth_write_buf_|. While bytes
//   remain, the caller presumably keeps writing.
// - Once the whole digest has been written, |auth_write_buf_| is released
//   and CheckDone() decides whether the whole exchange has finished.
// NOTE(review): |callback_called| can be nullptr (OnAuthBytesWritten()
// passes nullptr), so every dereference of it must be null-guarded.
// NOTE(review): on the error path |*callback_called| is set to false. If that
// path also runs |done_callback_| (e.g. via NotifyError()), OnConnected()
// would go on to call ReadAuthenticationBytes() after the callback has run,
// when this object may already be destroyed. The flag should probably be
// true there; verify.
338 bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
339 int result
, bool* callback_called
) {
341 LOG(ERROR
) << "Error writing authentication: " << result
;
343 *callback_called
= false;
348 auth_write_buf_
->DidConsume(result
);
349 if (auth_write_buf_
->BytesRemaining() > 0)
352 auth_write_buf_
= nullptr;
353 CheckDone(callback_called
);
// Reads the peer's digest into the remaining capacity of |auth_read_buf_|.
// Synchronous results are processed by HandleAuthBytesRead(). If Read()
// returns ERR_IO_PENDING, OnAuthBytesRead() continues later.
357 void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
360 socket_
->Read(auth_read_buf_
.get(),
361 auth_read_buf_
->RemainingCapacity(),
362 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead
,
363 base::Unretained(this)));
364 if (result
== net::ERR_IO_PENDING
)
366 if (!HandleAuthBytesRead(result
))
// Completion callback for an asynchronous Read() of the peer's digest.
// Reading continues while HandleAuthBytesRead() returns true.
371 void SslHmacChannelAuthenticator::OnAuthBytesRead(int result
) {
372 DCHECK(CalledOnValidThread());
374 if (HandleAuthBytesRead(result
))
375 ReadAuthenticationBytes();
// Processes the result of one Read() of the peer's digest.
// - A non-positive |read_result| is reported through NotifyError().
// - Otherwise the bytes read are appended to |auth_read_buf_|.
// - Once kAuthDigestLength bytes have arrived, they are checked with
//   VerifyAuthBytes(). A mismatch is logged and reported as ERR_FAILED; a
//   match releases |auth_read_buf_|.
// NOTE(review): a read of 0 bytes (end of stream) reaches NotifyError(0).
// 0 is net::OK, so |done_callback_| would run with OK and a null socket.
// Consider mapping 0 to net::ERR_CONNECTION_CLOSED.
378 bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result
) {
379 if (read_result
<= 0) {
380 NotifyError(read_result
);
384 auth_read_buf_
->set_offset(auth_read_buf_
->offset() + read_result
);
385 if (auth_read_buf_
->RemainingCapacity() > 0)
388 if (!VerifyAuthBytes(std::string(
389 auth_read_buf_
->StartOfBuffer(),
390 auth_read_buf_
->StartOfBuffer() + kAuthDigestLength
))) {
391 LOG(WARNING
) << "Mismatched authentication";
392 NotifyError(net::ERR_FAILED
);
396 auth_read_buf_
= nullptr;
// Returns true if |received_auth_bytes| equals the digest the peer should
// have produced: GetAuthBytes() with the opposite exporter label from the one
// this side writes in OnConnected(). Presumably returns false when that
// digest cannot be computed (GetAuthBytes() returned an empty string).
// The comparison uses crypto::SecureMemEqual rather than operator==,
// presumably to avoid timing side channels.
// NOTE(review): assumes GetAuthBytes() returns at least kAuthDigestLength
// bytes whenever it returns a non-empty string.
401 bool SslHmacChannelAuthenticator::VerifyAuthBytes(
402 const std::string
& received_auth_bytes
) {
403 DCHECK(received_auth_bytes
.length() == kAuthDigestLength
);
405 // Compute expected auth bytes.
406 std::string auth_bytes
= GetAuthBytes(
407 socket_
.get(), is_ssl_server() ?
408 kClientAuthSslExporterLabel
: kHostAuthSslExporterLabel
, auth_key_
);
409 if (auth_bytes
.empty())
412 return crypto::SecureMemEqual(received_auth_bytes
.data(),
413 &(auth_bytes
[0]), kAuthDigestLength
);
// Finishes authentication once both |auth_write_buf_| and |auth_read_buf_|
// have been released, i.e. both directions of the digest exchange are
// complete. It then:
// - sets |*callback_called| to true; and
// - passes |socket_|, wrapped in a P2PStreamSocketAdapter, to
//   |done_callback_|.
// NOTE(review): callers may pass a null |callback_called| (for example via
// OnAuthBytesWritten()), so its dereference must be null-guarded.
416 void SslHmacChannelAuthenticator::CheckDone(bool* callback_called
) {
417 if (auth_write_buf_
.get() == nullptr && auth_read_buf_
.get() == nullptr) {
418 DCHECK(socket_
.get() != nullptr);
420 *callback_called
= true;
422 base::ResetAndReturn(&done_callback_
)
424 make_scoped_ptr(new P2PStreamSocketAdapter(socket_
.Pass())));
// Reports |error| to the caller by running |done_callback_| with a null
// socket. base::ResetAndReturn() clears |done_callback_| before running it,
// so a second call would find the callback already reset. Nothing here
// touches |this| after the callback runs, so the callback may destroy the
// authenticator.
428 void SslHmacChannelAuthenticator::NotifyError(int error
) {
429 base::ResetAndReturn(&done_callback_
).Run(error
, nullptr);
432 } // namespace protocol
433 } // namespace remoting