Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / net / ssl / default_channel_id_store.cc
blobb0497e1f63184e80c5160fc8db5193f8903c379c
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
5 #include "net/ssl/default_channel_id_store.h"
7 #include "base/bind.h"
8 #include "base/message_loop/message_loop.h"
9 #include "base/metrics/histogram_macros.h"
10 #include "crypto/ec_private_key.h"
11 #include "net/base/net_errors.h"
13 namespace net {
15 // --------------------------------------------------------------------------
16 // Task
17 class DefaultChannelIDStore::Task {
18 public:
19 virtual ~Task();
21 // Runs the task and invokes the client callback on the thread that
22 // originally constructed the task.
23 virtual void Run(DefaultChannelIDStore* store) = 0;
25 protected:
26 void InvokeCallback(base::Closure callback) const;
29 DefaultChannelIDStore::Task::~Task() {
32 void DefaultChannelIDStore::Task::InvokeCallback(
33 base::Closure callback) const {
34 if (!callback.is_null())
35 callback.Run();
38 // --------------------------------------------------------------------------
39 // GetChannelIDTask
40 class DefaultChannelIDStore::GetChannelIDTask
41 : public DefaultChannelIDStore::Task {
42 public:
43 GetChannelIDTask(const std::string& server_identifier,
44 const GetChannelIDCallback& callback);
45 ~GetChannelIDTask() override;
46 void Run(DefaultChannelIDStore* store) override;
48 private:
49 std::string server_identifier_;
50 GetChannelIDCallback callback_;
53 DefaultChannelIDStore::GetChannelIDTask::GetChannelIDTask(
54 const std::string& server_identifier,
55 const GetChannelIDCallback& callback)
56 : server_identifier_(server_identifier),
57 callback_(callback) {
60 DefaultChannelIDStore::GetChannelIDTask::~GetChannelIDTask() {
63 void DefaultChannelIDStore::GetChannelIDTask::Run(
64 DefaultChannelIDStore* store) {
65 scoped_ptr<crypto::ECPrivateKey> key_result;
66 int err = store->GetChannelID(server_identifier_, &key_result,
67 GetChannelIDCallback());
68 DCHECK(err != ERR_IO_PENDING);
70 InvokeCallback(base::Bind(callback_, err, server_identifier_,
71 base::Passed(key_result.Pass())));
74 // --------------------------------------------------------------------------
75 // SetChannelIDTask
76 class DefaultChannelIDStore::SetChannelIDTask
77 : public DefaultChannelIDStore::Task {
78 public:
79 SetChannelIDTask(scoped_ptr<ChannelID> channel_id);
80 ~SetChannelIDTask() override;
81 void Run(DefaultChannelIDStore* store) override;
83 private:
84 scoped_ptr<ChannelID> channel_id_;
87 DefaultChannelIDStore::SetChannelIDTask::SetChannelIDTask(
88 scoped_ptr<ChannelID> channel_id)
89 : channel_id_(channel_id.Pass()) {
92 DefaultChannelIDStore::SetChannelIDTask::~SetChannelIDTask() {
95 void DefaultChannelIDStore::SetChannelIDTask::Run(
96 DefaultChannelIDStore* store) {
97 store->SyncSetChannelID(channel_id_.Pass());
100 // --------------------------------------------------------------------------
101 // DeleteChannelIDTask
102 class DefaultChannelIDStore::DeleteChannelIDTask
103 : public DefaultChannelIDStore::Task {
104 public:
105 DeleteChannelIDTask(const std::string& server_identifier,
106 const base::Closure& callback);
107 ~DeleteChannelIDTask() override;
108 void Run(DefaultChannelIDStore* store) override;
110 private:
111 std::string server_identifier_;
112 base::Closure callback_;
115 DefaultChannelIDStore::DeleteChannelIDTask::
116 DeleteChannelIDTask(
117 const std::string& server_identifier,
118 const base::Closure& callback)
119 : server_identifier_(server_identifier),
120 callback_(callback) {
123 DefaultChannelIDStore::DeleteChannelIDTask::
124 ~DeleteChannelIDTask() {
127 void DefaultChannelIDStore::DeleteChannelIDTask::Run(
128 DefaultChannelIDStore* store) {
129 store->SyncDeleteChannelID(server_identifier_);
131 InvokeCallback(callback_);
134 // --------------------------------------------------------------------------
135 // DeleteAllCreatedBetweenTask
136 class DefaultChannelIDStore::DeleteAllCreatedBetweenTask
137 : public DefaultChannelIDStore::Task {
138 public:
139 DeleteAllCreatedBetweenTask(base::Time delete_begin,
140 base::Time delete_end,
141 const base::Closure& callback);
142 ~DeleteAllCreatedBetweenTask() override;
143 void Run(DefaultChannelIDStore* store) override;
145 private:
146 base::Time delete_begin_;
147 base::Time delete_end_;
148 base::Closure callback_;
151 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
152 DeleteAllCreatedBetweenTask(
153 base::Time delete_begin,
154 base::Time delete_end,
155 const base::Closure& callback)
156 : delete_begin_(delete_begin),
157 delete_end_(delete_end),
158 callback_(callback) {
161 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
162 ~DeleteAllCreatedBetweenTask() {
165 void DefaultChannelIDStore::DeleteAllCreatedBetweenTask::Run(
166 DefaultChannelIDStore* store) {
167 store->SyncDeleteAllCreatedBetween(delete_begin_, delete_end_);
169 InvokeCallback(callback_);
172 // --------------------------------------------------------------------------
173 // GetAllChannelIDsTask
174 class DefaultChannelIDStore::GetAllChannelIDsTask
175 : public DefaultChannelIDStore::Task {
176 public:
177 explicit GetAllChannelIDsTask(const GetChannelIDListCallback& callback);
178 ~GetAllChannelIDsTask() override;
179 void Run(DefaultChannelIDStore* store) override;
181 private:
182 std::string server_identifier_;
183 GetChannelIDListCallback callback_;
186 DefaultChannelIDStore::GetAllChannelIDsTask::
187 GetAllChannelIDsTask(const GetChannelIDListCallback& callback)
188 : callback_(callback) {
191 DefaultChannelIDStore::GetAllChannelIDsTask::
192 ~GetAllChannelIDsTask() {
195 void DefaultChannelIDStore::GetAllChannelIDsTask::Run(
196 DefaultChannelIDStore* store) {
197 ChannelIDList key_list;
198 store->SyncGetAllChannelIDs(&key_list);
200 InvokeCallback(base::Bind(callback_, key_list));
203 // --------------------------------------------------------------------------
204 // DefaultChannelIDStore
206 DefaultChannelIDStore::DefaultChannelIDStore(
207 PersistentStore* store)
208 : initialized_(false),
209 loaded_(false),
210 store_(store),
211 weak_ptr_factory_(this) {}
213 int DefaultChannelIDStore::GetChannelID(
214 const std::string& server_identifier,
215 scoped_ptr<crypto::ECPrivateKey>* key_result,
216 const GetChannelIDCallback& callback) {
217 DCHECK(CalledOnValidThread());
218 InitIfNecessary();
220 if (!loaded_) {
221 EnqueueTask(scoped_ptr<Task>(
222 new GetChannelIDTask(server_identifier, callback)));
223 return ERR_IO_PENDING;
226 ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
228 if (it == channel_ids_.end())
229 return ERR_FILE_NOT_FOUND;
231 ChannelID* channel_id = it->second;
232 key_result->reset(channel_id->key()->Copy());
234 return OK;
237 void DefaultChannelIDStore::SetChannelID(scoped_ptr<ChannelID> channel_id) {
238 auto task = new SetChannelIDTask(channel_id.Pass());
239 RunOrEnqueueTask(scoped_ptr<Task>(task));
242 void DefaultChannelIDStore::DeleteChannelID(
243 const std::string& server_identifier,
244 const base::Closure& callback) {
245 RunOrEnqueueTask(scoped_ptr<Task>(
246 new DeleteChannelIDTask(server_identifier, callback)));
249 void DefaultChannelIDStore::DeleteAllCreatedBetween(
250 base::Time delete_begin,
251 base::Time delete_end,
252 const base::Closure& callback) {
253 RunOrEnqueueTask(scoped_ptr<Task>(
254 new DeleteAllCreatedBetweenTask(delete_begin, delete_end, callback)));
257 void DefaultChannelIDStore::DeleteAll(
258 const base::Closure& callback) {
259 DeleteAllCreatedBetween(base::Time(), base::Time(), callback);
262 void DefaultChannelIDStore::GetAllChannelIDs(
263 const GetChannelIDListCallback& callback) {
264 RunOrEnqueueTask(scoped_ptr<Task>(new GetAllChannelIDsTask(callback)));
267 int DefaultChannelIDStore::GetChannelIDCount() {
268 DCHECK(CalledOnValidThread());
270 return channel_ids_.size();
273 void DefaultChannelIDStore::SetForceKeepSessionState() {
274 DCHECK(CalledOnValidThread());
275 InitIfNecessary();
277 if (store_.get())
278 store_->SetForceKeepSessionState();
281 DefaultChannelIDStore::~DefaultChannelIDStore() {
282 DeleteAllInMemory();
285 void DefaultChannelIDStore::DeleteAllInMemory() {
286 DCHECK(CalledOnValidThread());
288 for (ChannelIDMap::iterator it = channel_ids_.begin();
289 it != channel_ids_.end(); ++it) {
290 delete it->second;
292 channel_ids_.clear();
295 void DefaultChannelIDStore::InitStore() {
296 DCHECK(CalledOnValidThread());
297 DCHECK(store_.get()) << "Store must exist to initialize";
298 DCHECK(!loaded_);
300 store_->Load(base::Bind(&DefaultChannelIDStore::OnLoaded,
301 weak_ptr_factory_.GetWeakPtr()));
304 void DefaultChannelIDStore::OnLoaded(
305 scoped_ptr<ScopedVector<ChannelID> > channel_ids) {
306 DCHECK(CalledOnValidThread());
308 for (std::vector<ChannelID*>::const_iterator it = channel_ids->begin();
309 it != channel_ids->end(); ++it) {
310 DCHECK(channel_ids_.find((*it)->server_identifier()) ==
311 channel_ids_.end());
312 channel_ids_[(*it)->server_identifier()] = *it;
314 channel_ids->weak_clear();
316 loaded_ = true;
318 base::TimeDelta wait_time;
319 if (!waiting_tasks_.empty())
320 wait_time = base::TimeTicks::Now() - waiting_tasks_start_time_;
321 DVLOG(1) << "Task delay " << wait_time.InMilliseconds();
322 UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime",
323 wait_time,
324 base::TimeDelta::FromMilliseconds(1),
325 base::TimeDelta::FromMinutes(1),
326 50);
327 UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount",
328 waiting_tasks_.size());
331 for (ScopedVector<Task>::iterator i = waiting_tasks_.begin();
332 i != waiting_tasks_.end(); ++i)
333 (*i)->Run(this);
334 waiting_tasks_.clear();
337 void DefaultChannelIDStore::SyncSetChannelID(scoped_ptr<ChannelID> channel_id) {
338 DCHECK(CalledOnValidThread());
339 DCHECK(loaded_);
341 InternalDeleteChannelID(channel_id->server_identifier());
342 InternalInsertChannelID(channel_id.Pass());
345 void DefaultChannelIDStore::SyncDeleteChannelID(
346 const std::string& server_identifier) {
347 DCHECK(CalledOnValidThread());
348 DCHECK(loaded_);
349 InternalDeleteChannelID(server_identifier);
352 void DefaultChannelIDStore::SyncDeleteAllCreatedBetween(
353 base::Time delete_begin,
354 base::Time delete_end) {
355 DCHECK(CalledOnValidThread());
356 DCHECK(loaded_);
357 for (ChannelIDMap::iterator it = channel_ids_.begin();
358 it != channel_ids_.end();) {
359 ChannelIDMap::iterator cur = it;
360 ++it;
361 ChannelID* channel_id = cur->second;
362 if ((delete_begin.is_null() ||
363 channel_id->creation_time() >= delete_begin) &&
364 (delete_end.is_null() || channel_id->creation_time() < delete_end)) {
365 if (store_.get())
366 store_->DeleteChannelID(*channel_id);
367 delete channel_id;
368 channel_ids_.erase(cur);
373 void DefaultChannelIDStore::SyncGetAllChannelIDs(
374 ChannelIDList* channel_id_list) {
375 DCHECK(CalledOnValidThread());
376 DCHECK(loaded_);
377 for (ChannelIDMap::iterator it = channel_ids_.begin();
378 it != channel_ids_.end(); ++it)
379 channel_id_list->push_back(*it->second);
382 void DefaultChannelIDStore::EnqueueTask(scoped_ptr<Task> task) {
383 DCHECK(CalledOnValidThread());
384 DCHECK(!loaded_);
385 if (waiting_tasks_.empty())
386 waiting_tasks_start_time_ = base::TimeTicks::Now();
387 waiting_tasks_.push_back(task.Pass());
390 void DefaultChannelIDStore::RunOrEnqueueTask(scoped_ptr<Task> task) {
391 DCHECK(CalledOnValidThread());
392 InitIfNecessary();
394 if (!loaded_) {
395 EnqueueTask(task.Pass());
396 return;
399 task->Run(this);
402 void DefaultChannelIDStore::InternalDeleteChannelID(
403 const std::string& server_identifier) {
404 DCHECK(CalledOnValidThread());
405 DCHECK(loaded_);
407 ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
408 if (it == channel_ids_.end())
409 return; // There is nothing to delete.
411 ChannelID* channel_id = it->second;
412 if (store_.get())
413 store_->DeleteChannelID(*channel_id);
414 channel_ids_.erase(it);
415 delete channel_id;
418 void DefaultChannelIDStore::InternalInsertChannelID(
419 scoped_ptr<ChannelID> channel_id) {
420 DCHECK(CalledOnValidThread());
421 DCHECK(loaded_);
423 if (store_.get())
424 store_->AddChannelID(*(channel_id.get()));
425 const std::string& server_identifier = channel_id->server_identifier();
426 channel_ids_[server_identifier] = channel_id.release();
429 DefaultChannelIDStore::PersistentStore::PersistentStore() {}
431 DefaultChannelIDStore::PersistentStore::~PersistentStore() {}
433 } // namespace net