Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / remoting / protocol / ssl_hmac_channel_authenticator.cc
blob5658d5cff17c16b453c000334b518cc367c0656d
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h"
7 #include "base/bind.h"
8 #include "base/bind_helpers.h"
9 #include "base/callback_helpers.h"
10 #include "base/logging.h"
11 #include "crypto/secure_util.h"
12 #include "net/base/host_port_pair.h"
13 #include "net/base/io_buffer.h"
14 #include "net/base/net_errors.h"
15 #include "net/cert/cert_status_flags.h"
16 #include "net/cert/cert_verifier.h"
17 #include "net/cert/cert_verify_result.h"
18 #include "net/cert/x509_certificate.h"
19 #include "net/http/transport_security_state.h"
20 #include "net/socket/client_socket_handle.h"
21 #include "net/socket/ssl_client_socket.h"
22 #include "net/socket/ssl_server_socket.h"
23 #include "net/ssl/ssl_config_service.h"
24 #include "remoting/base/rsa_key_pair.h"
25 #include "remoting/protocol/auth_util.h"
26 #include "remoting/protocol/p2p_stream_socket.h"
28 #if defined(OS_NACL)
29 #include "net/socket/ssl_client_socket_openssl.h"
30 #else
31 #include "net/socket/client_socket_factory.h"
32 #endif
34 namespace remoting {
35 namespace protocol {
37 namespace {
39 // A CertVerifier which rejects every certificate.
40 class FailingCertVerifier : public net::CertVerifier {
41 public:
42 FailingCertVerifier() {}
43 ~FailingCertVerifier() override {}
45 int Verify(net::X509Certificate* cert,
46 const std::string& hostname,
47 const std::string& ocsp_response,
48 int flags,
49 net::CRLSet* crl_set,
50 net::CertVerifyResult* verify_result,
51 const net::CompletionCallback& callback,
52 scoped_ptr<Request>* out_req,
53 const net::BoundNetLog& net_log) override {
54 verify_result->verified_cert = cert;
55 verify_result->cert_status = net::CERT_STATUS_INVALID;
56 return net::ERR_CERT_INVALID;
60 // Implements net::StreamSocket interface on top of P2PStreamSocket to be passed
61 // to net::SSLClientSocket and net::SSLServerSocket.
62 class NetStreamSocketAdapter : public net::StreamSocket {
63 public:
64 NetStreamSocketAdapter(scoped_ptr<P2PStreamSocket> socket)
65 : socket_(socket.Pass()) {}
66 ~NetStreamSocketAdapter() override {}
68 int Read(net::IOBuffer* buf, int buf_len,
69 const net::CompletionCallback& callback) override {
70 return socket_->Read(buf, buf_len, callback);
72 int Write(net::IOBuffer* buf, int buf_len,
73 const net::CompletionCallback& callback) override {
74 return socket_->Write(buf, buf_len, callback);
77 int SetReceiveBufferSize(int32_t size) override {
78 NOTREACHED();
79 return net::ERR_FAILED;
82 int SetSendBufferSize(int32_t size) override {
83 NOTREACHED();
84 return net::ERR_FAILED;
87 int Connect(const net::CompletionCallback& callback) override {
88 NOTREACHED();
89 return net::ERR_FAILED;
91 void Disconnect() override { socket_.reset(); }
92 bool IsConnected() const override { return true; }
93 bool IsConnectedAndIdle() const override { return true; }
94 int GetPeerAddress(net::IPEndPoint* address) const override {
95 // SSL sockets call this function so it must return some result.
96 net::IPAddressNumber ip_address(net::kIPv4AddressSize);
97 *address = net::IPEndPoint(ip_address, 0);
98 return net::OK;
100 int GetLocalAddress(net::IPEndPoint* address) const override {
101 NOTREACHED();
102 return net::ERR_FAILED;
104 const net::BoundNetLog& NetLog() const override { return net_log_; }
105 void SetSubresourceSpeculation() override { NOTREACHED(); }
106 void SetOmniboxSpeculation() override { NOTREACHED(); }
107 bool WasEverUsed() const override {
108 NOTREACHED();
109 return true;
111 bool UsingTCPFastOpen() const override {
112 NOTREACHED();
113 return false;
115 void EnableTCPFastOpenIfSupported() override { NOTREACHED(); }
116 bool WasNpnNegotiated() const override {
117 NOTREACHED();
118 return false;
120 net::NextProto GetNegotiatedProtocol() const override {
121 NOTREACHED();
122 return net::kProtoUnknown;
124 bool GetSSLInfo(net::SSLInfo* ssl_info) override {
125 NOTREACHED();
126 return false;
128 void GetConnectionAttempts(net::ConnectionAttempts* out) const override {
129 NOTREACHED();
131 void ClearConnectionAttempts() override { NOTREACHED(); }
132 void AddConnectionAttempts(const net::ConnectionAttempts& attempts) override {
133 NOTREACHED();
136 private:
137 scoped_ptr<P2PStreamSocket> socket_;
138 net::BoundNetLog net_log_;
141 // Implements P2PStreamSocket interface on top of net::StreamSocket.
142 class P2PStreamSocketAdapter : public P2PStreamSocket {
143 public:
144 P2PStreamSocketAdapter(scoped_ptr<net::StreamSocket> socket)
145 : socket_(socket.Pass()) {}
146 ~P2PStreamSocketAdapter() override {}
148 int Read(const scoped_refptr<net::IOBuffer>& buf, int buf_len,
149 const net::CompletionCallback& callback) override {
150 return socket_->Read(buf.get(), buf_len, callback);
152 int Write(const scoped_refptr<net::IOBuffer>& buf, int buf_len,
153 const net::CompletionCallback& callback) override {
154 return socket_->Write(buf.get(), buf_len, callback);
157 private:
158 scoped_ptr<net::StreamSocket> socket_;
161 } // namespace
163 // static
164 scoped_ptr<SslHmacChannelAuthenticator>
165 SslHmacChannelAuthenticator::CreateForClient(
166 const std::string& remote_cert,
167 const std::string& auth_key) {
168 scoped_ptr<SslHmacChannelAuthenticator> result(
169 new SslHmacChannelAuthenticator(auth_key));
170 result->remote_cert_ = remote_cert;
171 return result.Pass();
174 scoped_ptr<SslHmacChannelAuthenticator>
175 SslHmacChannelAuthenticator::CreateForHost(
176 const std::string& local_cert,
177 scoped_refptr<RsaKeyPair> key_pair,
178 const std::string& auth_key) {
179 scoped_ptr<SslHmacChannelAuthenticator> result(
180 new SslHmacChannelAuthenticator(auth_key));
181 result->local_cert_ = local_cert;
182 result->local_key_pair_ = key_pair;
183 return result.Pass();
186 SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
187 const std::string& auth_key)
188 : auth_key_(auth_key) {
191 SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
194 void SslHmacChannelAuthenticator::SecureAndAuthenticate(
195 scoped_ptr<P2PStreamSocket> socket,
196 const DoneCallback& done_callback) {
197 DCHECK(CalledOnValidThread());
199 done_callback_ = done_callback;
201 int result;
202 if (is_ssl_server()) {
203 #if defined(OS_NACL)
204 // Client plugin doesn't use server SSL sockets, and so SSLServerSocket
205 // implementation is not compiled for NaCl as part of net_nacl.
206 NOTREACHED();
207 result = net::ERR_FAILED;
208 #else
209 scoped_refptr<net::X509Certificate> cert =
210 net::X509Certificate::CreateFromBytes(
211 local_cert_.data(), local_cert_.length());
212 if (!cert.get()) {
213 LOG(ERROR) << "Failed to parse X509Certificate";
214 NotifyError(net::ERR_FAILED);
215 return;
218 net::SSLConfig ssl_config;
219 ssl_config.require_ecdhe = true;
221 scoped_ptr<net::SSLServerSocket> server_socket = net::CreateSSLServerSocket(
222 make_scoped_ptr(new NetStreamSocketAdapter(socket.Pass())), cert.get(),
223 local_key_pair_->private_key(), ssl_config);
224 net::SSLServerSocket* raw_server_socket = server_socket.get();
225 socket_ = server_socket.Pass();
226 result = raw_server_socket->Handshake(
227 base::Bind(&SslHmacChannelAuthenticator::OnConnected,
228 base::Unretained(this)));
229 #endif
230 } else {
231 transport_security_state_.reset(new net::TransportSecurityState);
232 cert_verifier_.reset(new FailingCertVerifier);
234 net::SSLConfig::CertAndStatus cert_and_status;
235 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
236 cert_and_status.der_cert = remote_cert_;
238 net::SSLConfig ssl_config;
239 // Certificate verification and revocation checking are not needed
240 // because we use self-signed certs. Disable it so that the SSL
241 // layer doesn't try to initialize OCSP (OCSP works only on the IO
242 // thread).
243 ssl_config.cert_io_enabled = false;
244 ssl_config.rev_checking_enabled = false;
245 ssl_config.allowed_bad_certs.push_back(cert_and_status);
246 ssl_config.require_ecdhe = false;
248 net::HostPortPair host_and_port(kSslFakeHostName, 0);
249 net::SSLClientSocketContext context;
250 context.transport_security_state = transport_security_state_.get();
251 context.cert_verifier = cert_verifier_.get();
252 scoped_ptr<net::ClientSocketHandle> socket_handle(
253 new net::ClientSocketHandle);
254 socket_handle->SetSocket(
255 make_scoped_ptr(new NetStreamSocketAdapter(socket.Pass())));
257 #if defined(OS_NACL)
258 // net_nacl doesn't include ClientSocketFactory.
259 socket_.reset(new net::SSLClientSocketOpenSSL(
260 socket_handle.Pass(), host_and_port, ssl_config, context));
261 #else
262 socket_ =
263 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
264 socket_handle.Pass(), host_and_port, ssl_config, context);
265 #endif
267 result = socket_->Connect(
268 base::Bind(&SslHmacChannelAuthenticator::OnConnected,
269 base::Unretained(this)));
272 if (result == net::ERR_IO_PENDING)
273 return;
275 OnConnected(result);
278 bool SslHmacChannelAuthenticator::is_ssl_server() {
279 return local_key_pair_.get() != nullptr;
282 void SslHmacChannelAuthenticator::OnConnected(int result) {
283 if (result != net::OK) {
284 LOG(WARNING) << "Failed to establish SSL connection. Error: "
285 << net::ErrorToString(result);
286 NotifyError(result);
287 return;
290 // Generate authentication digest to write to the socket.
291 std::string auth_bytes = GetAuthBytes(
292 socket_.get(), is_ssl_server() ?
293 kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_);
294 if (auth_bytes.empty()) {
295 NotifyError(net::ERR_FAILED);
296 return;
299 // Allocate a buffer to write the digest.
300 auth_write_buf_ = new net::DrainableIOBuffer(
301 new net::StringIOBuffer(auth_bytes), auth_bytes.size());
303 // Read an incoming token.
304 auth_read_buf_ = new net::GrowableIOBuffer();
305 auth_read_buf_->SetCapacity(kAuthDigestLength);
307 // If WriteAuthenticationBytes() results in |done_callback_| being
308 // called then we must not do anything else because this object may
309 // be destroyed at that point.
310 bool callback_called = false;
311 WriteAuthenticationBytes(&callback_called);
312 if (!callback_called)
313 ReadAuthenticationBytes();
316 void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
317 bool* callback_called) {
318 while (true) {
319 int result = socket_->Write(
320 auth_write_buf_.get(),
321 auth_write_buf_->BytesRemaining(),
322 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten,
323 base::Unretained(this)));
324 if (result == net::ERR_IO_PENDING)
325 break;
326 if (!HandleAuthBytesWritten(result, callback_called))
327 break;
331 void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) {
332 DCHECK(CalledOnValidThread());
334 if (HandleAuthBytesWritten(result, nullptr))
335 WriteAuthenticationBytes(nullptr);
338 bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
339 int result, bool* callback_called) {
340 if (result <= 0) {
341 LOG(ERROR) << "Error writing authentication: " << result;
342 if (callback_called)
343 *callback_called = false;
344 NotifyError(result);
345 return false;
348 auth_write_buf_->DidConsume(result);
349 if (auth_write_buf_->BytesRemaining() > 0)
350 return true;
352 auth_write_buf_ = nullptr;
353 CheckDone(callback_called);
354 return false;
357 void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
358 while (true) {
359 int result =
360 socket_->Read(auth_read_buf_.get(),
361 auth_read_buf_->RemainingCapacity(),
362 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead,
363 base::Unretained(this)));
364 if (result == net::ERR_IO_PENDING)
365 break;
366 if (!HandleAuthBytesRead(result))
367 break;
371 void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) {
372 DCHECK(CalledOnValidThread());
374 if (HandleAuthBytesRead(result))
375 ReadAuthenticationBytes();
378 bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) {
379 if (read_result <= 0) {
380 NotifyError(read_result);
381 return false;
384 auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result);
385 if (auth_read_buf_->RemainingCapacity() > 0)
386 return true;
388 if (!VerifyAuthBytes(std::string(
389 auth_read_buf_->StartOfBuffer(),
390 auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) {
391 LOG(WARNING) << "Mismatched authentication";
392 NotifyError(net::ERR_FAILED);
393 return false;
396 auth_read_buf_ = nullptr;
397 CheckDone(nullptr);
398 return false;
401 bool SslHmacChannelAuthenticator::VerifyAuthBytes(
402 const std::string& received_auth_bytes) {
403 DCHECK(received_auth_bytes.length() == kAuthDigestLength);
405 // Compute expected auth bytes.
406 std::string auth_bytes = GetAuthBytes(
407 socket_.get(), is_ssl_server() ?
408 kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_);
409 if (auth_bytes.empty())
410 return false;
412 return crypto::SecureMemEqual(received_auth_bytes.data(),
413 &(auth_bytes[0]), kAuthDigestLength);
416 void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) {
417 if (auth_write_buf_.get() == nullptr && auth_read_buf_.get() == nullptr) {
418 DCHECK(socket_.get() != nullptr);
419 if (callback_called)
420 *callback_called = true;
422 base::ResetAndReturn(&done_callback_)
423 .Run(net::OK,
424 make_scoped_ptr(new P2PStreamSocketAdapter(socket_.Pass())));
428 void SslHmacChannelAuthenticator::NotifyError(int error) {
429 base::ResetAndReturn(&done_callback_).Run(error, nullptr);
432 } // namespace protocol
433 } // namespace remoting