Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / net / ssl / channel_id_service.cc
blob57f9ef90dca28faa7a25ec5441ac2cda9a6ea9b7
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/ssl/channel_id_service.h"
7 #include <algorithm>
8 #include <limits>
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/callback_helpers.h"
13 #include "base/compiler_specific.h"
14 #include "base/location.h"
15 #include "base/logging.h"
16 #include "base/memory/ref_counted.h"
17 #include "base/memory/scoped_ptr.h"
18 #include "base/metrics/histogram_macros.h"
19 #include "base/rand_util.h"
20 #include "base/single_thread_task_runner.h"
21 #include "base/stl_util.h"
22 #include "base/task_runner.h"
23 #include "base/thread_task_runner_handle.h"
24 #include "crypto/ec_private_key.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/registry_controlled_domains/registry_controlled_domain.h"
27 #include "net/cert/x509_certificate.h"
28 #include "net/cert/x509_util.h"
29 #include "url/gurl.h"
31 #if !defined(USE_OPENSSL)
32 #include <private/pprthred.h> // PR_DetachThread
33 #endif
35 namespace net {
37 namespace {
// Final outcome of a GetChannelID or GetOrCreateChannelID call, as reported to
// the GetDomainBoundCertResult histogram.
// Do not re-use values.
enum GetChannelIDResult {
  // An existing domain bound cert was found and returned synchronously.
  SYNC_SUCCESS = 0,
  // A domain bound cert was retrieved or generated and returned
  // asynchronously.
  ASYNC_SUCCESS = 1,
  // The retrieval/generation request was cancelled before the cert generation
  // completed.
  ASYNC_CANCELLED = 2,
  // Cert generation failed.
  ASYNC_FAILURE_KEYGEN = 3,
  // Result code 4 was removed (ASYNC_FAILURE_CREATE_CERT).
  ASYNC_FAILURE_EXPORT_KEY = 5,
  ASYNC_FAILURE_UNKNOWN = 6,
  // GetChannelID or GetOrCreateChannelID was called with invalid arguments.
  INVALID_ARGUMENT = 7,
  // None of the cert types the server requested are supported.
  UNSUPPORTED_TYPE = 8,
  // The server asked for a different type of cert while one was being
  // generated.
  TYPE_MISMATCH = 9,
  // A worker to generate a cert could not be started.
  WORKER_FAILURE = 10,
  // Not a result: passed as the boundary value to the enumeration histogram.
  GET_CHANNEL_ID_RESULT_MAX
};
67 void RecordGetChannelIDResult(GetChannelIDResult result) {
68 UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.GetDomainBoundCertResult", result,
69 GET_CHANNEL_ID_RESULT_MAX);
72 void RecordGetChannelIDTime(base::TimeDelta request_time) {
73 UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GetCertTime",
74 request_time,
75 base::TimeDelta::FromMilliseconds(1),
76 base::TimeDelta::FromMinutes(5),
77 50);
80 // On success, returns a ChannelID object and sets |*error| to OK.
81 // Otherwise, returns NULL, and |*error| will be set to a net error code.
82 // |serial_number| is passed in because base::RandInt cannot be called from an
83 // unjoined thread, due to relying on a non-leaked LazyInstance
84 scoped_ptr<ChannelIDStore::ChannelID> GenerateChannelID(
85 const std::string& server_identifier,
86 int* error) {
87 scoped_ptr<ChannelIDStore::ChannelID> result;
89 base::TimeTicks start = base::TimeTicks::Now();
90 base::Time creation_time = base::Time::Now();
91 scoped_ptr<crypto::ECPrivateKey> key(crypto::ECPrivateKey::Create());
93 if (!key) {
94 DLOG(ERROR) << "Unable to create channel ID key pair";
95 *error = ERR_KEY_GENERATION_FAILED;
96 return result.Pass();
99 result.reset(new ChannelIDStore::ChannelID(server_identifier, creation_time,
100 key.Pass()));
101 UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GenerateCertTime",
102 base::TimeTicks::Now() - start,
103 base::TimeDelta::FromMilliseconds(1),
104 base::TimeDelta::FromMinutes(5),
105 50);
106 *error = OK;
107 return result.Pass();
110 } // namespace
112 // ChannelIDServiceWorker runs on a worker thread and takes care of the
113 // blocking process of performing key generation. Will take care of deleting
114 // itself once Start() is called.
115 class ChannelIDServiceWorker {
116 public:
117 typedef base::Callback<void(
118 const std::string&,
119 int,
120 scoped_ptr<ChannelIDStore::ChannelID>)> WorkerDoneCallback;
122 ChannelIDServiceWorker(const std::string& server_identifier,
123 const WorkerDoneCallback& callback)
124 : server_identifier_(server_identifier),
125 origin_task_runner_(base::ThreadTaskRunnerHandle::Get()),
126 callback_(callback) {}
128 // Starts the worker on |task_runner|. If the worker fails to start, such as
129 // if the task runner is shutting down, then it will take care of deleting
130 // itself.
131 bool Start(const scoped_refptr<base::TaskRunner>& task_runner) {
132 DCHECK(origin_task_runner_->RunsTasksOnCurrentThread());
134 return task_runner->PostTask(
135 FROM_HERE,
136 base::Bind(&ChannelIDServiceWorker::Run, base::Owned(this)));
139 private:
140 void Run() {
141 // Runs on a worker thread.
142 int error = ERR_FAILED;
143 scoped_ptr<ChannelIDStore::ChannelID> channel_id =
144 GenerateChannelID(server_identifier_, &error);
145 #if !defined(USE_OPENSSL)
146 // Detach the thread from NSPR.
147 // Calling NSS functions attaches the thread to NSPR, which stores
148 // the NSPR thread ID in thread-specific data.
149 // The threads in our thread pool terminate after we have called
150 // PR_Cleanup. Unless we detach them from NSPR, net_unittests gets
151 // segfaults on shutdown when the threads' thread-specific data
152 // destructors run.
153 PR_DetachThread();
154 #endif
155 origin_task_runner_->PostTask(
156 FROM_HERE, base::Bind(callback_, server_identifier_, error,
157 base::Passed(&channel_id)));
160 const std::string server_identifier_;
161 scoped_refptr<base::SequencedTaskRunner> origin_task_runner_;
162 WorkerDoneCallback callback_;
164 DISALLOW_COPY_AND_ASSIGN(ChannelIDServiceWorker);
167 // A ChannelIDServiceJob is a one-to-one counterpart of an
168 // ChannelIDServiceWorker. It lives only on the ChannelIDService's
169 // origin task runner's thread.
170 class ChannelIDServiceJob {
171 public:
172 ChannelIDServiceJob(bool create_if_missing)
173 : create_if_missing_(create_if_missing) {
176 ~ChannelIDServiceJob() { DCHECK(requests_.empty()); }
178 void AddRequest(ChannelIDService::Request* request,
179 bool create_if_missing = false) {
180 create_if_missing_ |= create_if_missing;
181 requests_.push_back(request);
184 void HandleResult(int error, scoped_ptr<crypto::ECPrivateKey> key) {
185 PostAll(error, key.Pass());
188 bool CreateIfMissing() const { return create_if_missing_; }
190 void CancelRequest(ChannelIDService::Request* req) {
191 auto it = std::find(requests_.begin(), requests_.end(), req);
192 if (it != requests_.end())
193 requests_.erase(it);
196 private:
197 void PostAll(int error, scoped_ptr<crypto::ECPrivateKey> key) {
198 std::vector<ChannelIDService::Request*> requests;
199 requests_.swap(requests);
201 for (std::vector<ChannelIDService::Request*>::iterator i = requests.begin();
202 i != requests.end(); i++) {
203 scoped_ptr<crypto::ECPrivateKey> key_copy;
204 if (key)
205 key_copy.reset(key->Copy());
206 (*i)->Post(error, key_copy.Pass());
210 std::vector<ChannelIDService::Request*> requests_;
211 bool create_if_missing_;
214 // static
215 const char ChannelIDService::kEPKIPassword[] = "";
217 ChannelIDService::Request::Request() : service_(NULL) {
220 ChannelIDService::Request::~Request() {
221 Cancel();
224 void ChannelIDService::Request::Cancel() {
225 if (service_) {
226 RecordGetChannelIDResult(ASYNC_CANCELLED);
227 callback_.Reset();
228 job_->CancelRequest(this);
230 service_ = NULL;
234 void ChannelIDService::Request::RequestStarted(
235 ChannelIDService* service,
236 base::TimeTicks request_start,
237 const CompletionCallback& callback,
238 scoped_ptr<crypto::ECPrivateKey>* key,
239 ChannelIDServiceJob* job) {
240 DCHECK(service_ == NULL);
241 service_ = service;
242 request_start_ = request_start;
243 callback_ = callback;
244 key_ = key;
245 job_ = job;
248 void ChannelIDService::Request::Post(int error,
249 scoped_ptr<crypto::ECPrivateKey> key) {
250 switch (error) {
251 case OK: {
252 base::TimeDelta request_time = base::TimeTicks::Now() - request_start_;
253 UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GetCertTimeAsync",
254 request_time,
255 base::TimeDelta::FromMilliseconds(1),
256 base::TimeDelta::FromMinutes(5), 50);
257 RecordGetChannelIDTime(request_time);
258 RecordGetChannelIDResult(ASYNC_SUCCESS);
259 break;
261 case ERR_KEY_GENERATION_FAILED:
262 RecordGetChannelIDResult(ASYNC_FAILURE_KEYGEN);
263 break;
264 case ERR_PRIVATE_KEY_EXPORT_FAILED:
265 RecordGetChannelIDResult(ASYNC_FAILURE_EXPORT_KEY);
266 break;
267 case ERR_INSUFFICIENT_RESOURCES:
268 RecordGetChannelIDResult(WORKER_FAILURE);
269 break;
270 default:
271 RecordGetChannelIDResult(ASYNC_FAILURE_UNKNOWN);
272 break;
274 service_ = NULL;
275 DCHECK(!callback_.is_null());
276 if (key)
277 *key_ = key.Pass();
278 // Running the callback might delete |this| (e.g. the callback cleans up
279 // resources created for the request), so we can't touch any of our
280 // members afterwards. Reset callback_ first.
281 base::ResetAndReturn(&callback_).Run(error);
284 ChannelIDService::ChannelIDService(
285 ChannelIDStore* channel_id_store,
286 const scoped_refptr<base::TaskRunner>& task_runner)
287 : channel_id_store_(channel_id_store),
288 task_runner_(task_runner),
289 requests_(0),
290 key_store_hits_(0),
291 inflight_joins_(0),
292 workers_created_(0),
293 weak_ptr_factory_(this) {
296 ChannelIDService::~ChannelIDService() {
297 STLDeleteValues(&inflight_);
300 //static
301 std::string ChannelIDService::GetDomainForHost(const std::string& host) {
302 std::string domain =
303 registry_controlled_domains::GetDomainAndRegistry(
304 host, registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES);
305 if (domain.empty())
306 return host;
307 return domain;
310 int ChannelIDService::GetOrCreateChannelID(
311 const std::string& host,
312 scoped_ptr<crypto::ECPrivateKey>* key,
313 const CompletionCallback& callback,
314 Request* out_req) {
315 DVLOG(1) << __FUNCTION__ << " " << host;
316 DCHECK(CalledOnValidThread());
317 base::TimeTicks request_start = base::TimeTicks::Now();
319 if (callback.is_null() || !key || host.empty()) {
320 RecordGetChannelIDResult(INVALID_ARGUMENT);
321 return ERR_INVALID_ARGUMENT;
324 std::string domain = GetDomainForHost(host);
325 if (domain.empty()) {
326 RecordGetChannelIDResult(INVALID_ARGUMENT);
327 return ERR_INVALID_ARGUMENT;
330 requests_++;
332 // See if a request for the same domain is currently in flight.
333 bool create_if_missing = true;
334 if (JoinToInFlightRequest(request_start, domain, key, create_if_missing,
335 callback, out_req)) {
336 return ERR_IO_PENDING;
339 int err = LookupChannelID(request_start, domain, key, create_if_missing,
340 callback, out_req);
341 if (err == ERR_FILE_NOT_FOUND) {
342 // Sync lookup did not find a valid channel ID. Start generating a new one.
343 workers_created_++;
344 ChannelIDServiceWorker* worker = new ChannelIDServiceWorker(
345 domain,
346 base::Bind(&ChannelIDService::GeneratedChannelID,
347 weak_ptr_factory_.GetWeakPtr()));
348 if (!worker->Start(task_runner_)) {
349 // TODO(rkn): Log to the NetLog.
350 LOG(ERROR) << "ChannelIDServiceWorker couldn't be started.";
351 RecordGetChannelIDResult(WORKER_FAILURE);
352 return ERR_INSUFFICIENT_RESOURCES;
354 // We are waiting for key generation. Create a job & request to track it.
355 ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing);
356 inflight_[domain] = job;
358 job->AddRequest(out_req);
359 out_req->RequestStarted(this, request_start, callback, key, job);
360 return ERR_IO_PENDING;
363 return err;
366 int ChannelIDService::GetChannelID(const std::string& host,
367 scoped_ptr<crypto::ECPrivateKey>* key,
368 const CompletionCallback& callback,
369 Request* out_req) {
370 DVLOG(1) << __FUNCTION__ << " " << host;
371 DCHECK(CalledOnValidThread());
372 base::TimeTicks request_start = base::TimeTicks::Now();
374 if (callback.is_null() || !key || host.empty()) {
375 RecordGetChannelIDResult(INVALID_ARGUMENT);
376 return ERR_INVALID_ARGUMENT;
379 std::string domain = GetDomainForHost(host);
380 if (domain.empty()) {
381 RecordGetChannelIDResult(INVALID_ARGUMENT);
382 return ERR_INVALID_ARGUMENT;
385 requests_++;
387 // See if a request for the same domain currently in flight.
388 bool create_if_missing = false;
389 if (JoinToInFlightRequest(request_start, domain, key, create_if_missing,
390 callback, out_req)) {
391 return ERR_IO_PENDING;
394 int err = LookupChannelID(request_start, domain, key, create_if_missing,
395 callback, out_req);
396 return err;
399 void ChannelIDService::GotChannelID(int err,
400 const std::string& server_identifier,
401 scoped_ptr<crypto::ECPrivateKey> key) {
402 DCHECK(CalledOnValidThread());
404 std::map<std::string, ChannelIDServiceJob*>::iterator j;
405 j = inflight_.find(server_identifier);
406 if (j == inflight_.end()) {
407 NOTREACHED();
408 return;
411 if (err == OK) {
412 // Async DB lookup found a valid channel ID.
413 key_store_hits_++;
414 // ChannelIDService::Request::Post will do the histograms and stuff.
415 HandleResult(OK, server_identifier, key.Pass());
416 return;
418 // Async lookup failed or the channel ID was missing. Return the error
419 // directly, unless the channel ID was missing and a request asked to create
420 // one.
421 if (err != ERR_FILE_NOT_FOUND || !j->second->CreateIfMissing()) {
422 HandleResult(err, server_identifier, key.Pass());
423 return;
425 // At least one request asked to create a channel ID => start generating a new
426 // one.
427 workers_created_++;
428 ChannelIDServiceWorker* worker = new ChannelIDServiceWorker(
429 server_identifier,
430 base::Bind(&ChannelIDService::GeneratedChannelID,
431 weak_ptr_factory_.GetWeakPtr()));
432 if (!worker->Start(task_runner_)) {
433 // TODO(rkn): Log to the NetLog.
434 LOG(ERROR) << "ChannelIDServiceWorker couldn't be started.";
435 HandleResult(ERR_INSUFFICIENT_RESOURCES, server_identifier, nullptr);
439 ChannelIDStore* ChannelIDService::GetChannelIDStore() {
440 return channel_id_store_.get();
443 void ChannelIDService::GeneratedChannelID(
444 const std::string& server_identifier,
445 int error,
446 scoped_ptr<ChannelIDStore::ChannelID> channel_id) {
447 DCHECK(CalledOnValidThread());
449 scoped_ptr<crypto::ECPrivateKey> key;
450 if (error == OK) {
451 key.reset(channel_id->key()->Copy());
452 channel_id_store_->SetChannelID(channel_id.Pass());
454 HandleResult(error, server_identifier, key.Pass());
457 void ChannelIDService::HandleResult(int error,
458 const std::string& server_identifier,
459 scoped_ptr<crypto::ECPrivateKey> key) {
460 DCHECK(CalledOnValidThread());
462 std::map<std::string, ChannelIDServiceJob*>::iterator j;
463 j = inflight_.find(server_identifier);
464 if (j == inflight_.end()) {
465 NOTREACHED();
466 return;
468 ChannelIDServiceJob* job = j->second;
469 inflight_.erase(j);
471 job->HandleResult(error, key.Pass());
472 delete job;
475 bool ChannelIDService::JoinToInFlightRequest(
476 const base::TimeTicks& request_start,
477 const std::string& domain,
478 scoped_ptr<crypto::ECPrivateKey>* key,
479 bool create_if_missing,
480 const CompletionCallback& callback,
481 Request* out_req) {
482 ChannelIDServiceJob* job = NULL;
483 std::map<std::string, ChannelIDServiceJob*>::const_iterator j =
484 inflight_.find(domain);
485 if (j != inflight_.end()) {
486 // A request for the same domain is in flight already. We'll attach our
487 // callback, but we'll also mark it as requiring a channel ID if one's
488 // mising.
489 job = j->second;
490 inflight_joins_++;
492 job->AddRequest(out_req, create_if_missing);
493 out_req->RequestStarted(this, request_start, callback, key, job);
494 return true;
496 return false;
499 int ChannelIDService::LookupChannelID(const base::TimeTicks& request_start,
500 const std::string& domain,
501 scoped_ptr<crypto::ECPrivateKey>* key,
502 bool create_if_missing,
503 const CompletionCallback& callback,
504 Request* out_req) {
505 // Check if a channel ID key already exists for this domain.
506 int err = channel_id_store_->GetChannelID(
507 domain, key, base::Bind(&ChannelIDService::GotChannelID,
508 weak_ptr_factory_.GetWeakPtr()));
510 if (err == OK) {
511 // Sync lookup found a valid channel ID.
512 DVLOG(1) << "Channel ID store had valid key for " << domain;
513 key_store_hits_++;
514 RecordGetChannelIDResult(SYNC_SUCCESS);
515 base::TimeDelta request_time = base::TimeTicks::Now() - request_start;
516 UMA_HISTOGRAM_TIMES("DomainBoundCerts.GetCertTimeSync", request_time);
517 RecordGetChannelIDTime(request_time);
518 return OK;
521 if (err == ERR_IO_PENDING) {
522 // We are waiting for async DB lookup. Create a job & request to track it.
523 ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing);
524 inflight_[domain] = job;
526 job->AddRequest(out_req);
527 out_req->RequestStarted(this, request_start, callback, key, job);
528 return ERR_IO_PENDING;
531 return err;
534 int ChannelIDService::channel_id_count() {
535 return channel_id_store_->GetChannelIDCount();
538 } // namespace net