1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/ssl/default_channel_id_store.h"
8 #include "base/message_loop/message_loop.h"
9 #include "base/metrics/histogram_macros.h"
10 #include "crypto/ec_private_key.h"
11 #include "net/base/net_errors.h"
15 // --------------------------------------------------------------------------
17 class DefaultChannelIDStore::Task
{
21 // Runs the task and invokes the client callback on the thread that
22 // originally constructed the task.
23 virtual void Run(DefaultChannelIDStore
* store
) = 0;
26 void InvokeCallback(base::Closure callback
) const;
29 DefaultChannelIDStore::Task::~Task() {
32 void DefaultChannelIDStore::Task::InvokeCallback(
33 base::Closure callback
) const {
34 if (!callback
.is_null())
38 // --------------------------------------------------------------------------
40 class DefaultChannelIDStore::GetChannelIDTask
41 : public DefaultChannelIDStore::Task
{
43 GetChannelIDTask(const std::string
& server_identifier
,
44 const GetChannelIDCallback
& callback
);
45 ~GetChannelIDTask() override
;
46 void Run(DefaultChannelIDStore
* store
) override
;
49 std::string server_identifier_
;
50 GetChannelIDCallback callback_
;
53 DefaultChannelIDStore::GetChannelIDTask::GetChannelIDTask(
54 const std::string
& server_identifier
,
55 const GetChannelIDCallback
& callback
)
56 : server_identifier_(server_identifier
),
60 DefaultChannelIDStore::GetChannelIDTask::~GetChannelIDTask() {
63 void DefaultChannelIDStore::GetChannelIDTask::Run(
64 DefaultChannelIDStore
* store
) {
65 scoped_ptr
<crypto::ECPrivateKey
> key_result
;
66 int err
= store
->GetChannelID(server_identifier_
, &key_result
,
67 GetChannelIDCallback());
68 DCHECK(err
!= ERR_IO_PENDING
);
70 InvokeCallback(base::Bind(callback_
, err
, server_identifier_
,
71 base::Passed(key_result
.Pass())));
74 // --------------------------------------------------------------------------
76 class DefaultChannelIDStore::SetChannelIDTask
77 : public DefaultChannelIDStore::Task
{
79 SetChannelIDTask(scoped_ptr
<ChannelID
> channel_id
);
80 ~SetChannelIDTask() override
;
81 void Run(DefaultChannelIDStore
* store
) override
;
84 scoped_ptr
<ChannelID
> channel_id_
;
87 DefaultChannelIDStore::SetChannelIDTask::SetChannelIDTask(
88 scoped_ptr
<ChannelID
> channel_id
)
89 : channel_id_(channel_id
.Pass()) {
92 DefaultChannelIDStore::SetChannelIDTask::~SetChannelIDTask() {
95 void DefaultChannelIDStore::SetChannelIDTask::Run(
96 DefaultChannelIDStore
* store
) {
97 store
->SyncSetChannelID(channel_id_
.Pass());
100 // --------------------------------------------------------------------------
101 // DeleteChannelIDTask
102 class DefaultChannelIDStore::DeleteChannelIDTask
103 : public DefaultChannelIDStore::Task
{
105 DeleteChannelIDTask(const std::string
& server_identifier
,
106 const base::Closure
& callback
);
107 ~DeleteChannelIDTask() override
;
108 void Run(DefaultChannelIDStore
* store
) override
;
111 std::string server_identifier_
;
112 base::Closure callback_
;
115 DefaultChannelIDStore::DeleteChannelIDTask::
117 const std::string
& server_identifier
,
118 const base::Closure
& callback
)
119 : server_identifier_(server_identifier
),
120 callback_(callback
) {
123 DefaultChannelIDStore::DeleteChannelIDTask::
124 ~DeleteChannelIDTask() {
127 void DefaultChannelIDStore::DeleteChannelIDTask::Run(
128 DefaultChannelIDStore
* store
) {
129 store
->SyncDeleteChannelID(server_identifier_
);
131 InvokeCallback(callback_
);
134 // --------------------------------------------------------------------------
135 // DeleteAllCreatedBetweenTask
136 class DefaultChannelIDStore::DeleteAllCreatedBetweenTask
137 : public DefaultChannelIDStore::Task
{
139 DeleteAllCreatedBetweenTask(base::Time delete_begin
,
140 base::Time delete_end
,
141 const base::Closure
& callback
);
142 ~DeleteAllCreatedBetweenTask() override
;
143 void Run(DefaultChannelIDStore
* store
) override
;
146 base::Time delete_begin_
;
147 base::Time delete_end_
;
148 base::Closure callback_
;
151 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
152 DeleteAllCreatedBetweenTask(
153 base::Time delete_begin
,
154 base::Time delete_end
,
155 const base::Closure
& callback
)
156 : delete_begin_(delete_begin
),
157 delete_end_(delete_end
),
158 callback_(callback
) {
161 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
162 ~DeleteAllCreatedBetweenTask() {
165 void DefaultChannelIDStore::DeleteAllCreatedBetweenTask::Run(
166 DefaultChannelIDStore
* store
) {
167 store
->SyncDeleteAllCreatedBetween(delete_begin_
, delete_end_
);
169 InvokeCallback(callback_
);
172 // --------------------------------------------------------------------------
173 // GetAllChannelIDsTask
174 class DefaultChannelIDStore::GetAllChannelIDsTask
175 : public DefaultChannelIDStore::Task
{
177 explicit GetAllChannelIDsTask(const GetChannelIDListCallback
& callback
);
178 ~GetAllChannelIDsTask() override
;
179 void Run(DefaultChannelIDStore
* store
) override
;
182 std::string server_identifier_
;
183 GetChannelIDListCallback callback_
;
186 DefaultChannelIDStore::GetAllChannelIDsTask::
187 GetAllChannelIDsTask(const GetChannelIDListCallback
& callback
)
188 : callback_(callback
) {
191 DefaultChannelIDStore::GetAllChannelIDsTask::
192 ~GetAllChannelIDsTask() {
195 void DefaultChannelIDStore::GetAllChannelIDsTask::Run(
196 DefaultChannelIDStore
* store
) {
197 ChannelIDList key_list
;
198 store
->SyncGetAllChannelIDs(&key_list
);
200 InvokeCallback(base::Bind(callback_
, key_list
));
203 // --------------------------------------------------------------------------
204 // DefaultChannelIDStore
206 DefaultChannelIDStore::DefaultChannelIDStore(
207 PersistentStore
* store
)
208 : initialized_(false),
211 weak_ptr_factory_(this) {}
213 int DefaultChannelIDStore::GetChannelID(
214 const std::string
& server_identifier
,
215 scoped_ptr
<crypto::ECPrivateKey
>* key_result
,
216 const GetChannelIDCallback
& callback
) {
217 DCHECK(CalledOnValidThread());
221 EnqueueTask(scoped_ptr
<Task
>(
222 new GetChannelIDTask(server_identifier
, callback
)));
223 return ERR_IO_PENDING
;
226 ChannelIDMap::iterator it
= channel_ids_
.find(server_identifier
);
228 if (it
== channel_ids_
.end())
229 return ERR_FILE_NOT_FOUND
;
231 ChannelID
* channel_id
= it
->second
;
232 key_result
->reset(channel_id
->key()->Copy());
237 void DefaultChannelIDStore::SetChannelID(scoped_ptr
<ChannelID
> channel_id
) {
238 auto task
= new SetChannelIDTask(channel_id
.Pass());
239 RunOrEnqueueTask(scoped_ptr
<Task
>(task
));
242 void DefaultChannelIDStore::DeleteChannelID(
243 const std::string
& server_identifier
,
244 const base::Closure
& callback
) {
245 RunOrEnqueueTask(scoped_ptr
<Task
>(
246 new DeleteChannelIDTask(server_identifier
, callback
)));
249 void DefaultChannelIDStore::DeleteAllCreatedBetween(
250 base::Time delete_begin
,
251 base::Time delete_end
,
252 const base::Closure
& callback
) {
253 RunOrEnqueueTask(scoped_ptr
<Task
>(
254 new DeleteAllCreatedBetweenTask(delete_begin
, delete_end
, callback
)));
257 void DefaultChannelIDStore::DeleteAll(
258 const base::Closure
& callback
) {
259 DeleteAllCreatedBetween(base::Time(), base::Time(), callback
);
262 void DefaultChannelIDStore::GetAllChannelIDs(
263 const GetChannelIDListCallback
& callback
) {
264 RunOrEnqueueTask(scoped_ptr
<Task
>(new GetAllChannelIDsTask(callback
)));
267 int DefaultChannelIDStore::GetChannelIDCount() {
268 DCHECK(CalledOnValidThread());
270 return channel_ids_
.size();
273 void DefaultChannelIDStore::SetForceKeepSessionState() {
274 DCHECK(CalledOnValidThread());
278 store_
->SetForceKeepSessionState();
281 DefaultChannelIDStore::~DefaultChannelIDStore() {
285 void DefaultChannelIDStore::DeleteAllInMemory() {
286 DCHECK(CalledOnValidThread());
288 for (ChannelIDMap::iterator it
= channel_ids_
.begin();
289 it
!= channel_ids_
.end(); ++it
) {
292 channel_ids_
.clear();
295 void DefaultChannelIDStore::InitStore() {
296 DCHECK(CalledOnValidThread());
297 DCHECK(store_
.get()) << "Store must exist to initialize";
300 store_
->Load(base::Bind(&DefaultChannelIDStore::OnLoaded
,
301 weak_ptr_factory_
.GetWeakPtr()));
304 void DefaultChannelIDStore::OnLoaded(
305 scoped_ptr
<ScopedVector
<ChannelID
> > channel_ids
) {
306 DCHECK(CalledOnValidThread());
308 for (std::vector
<ChannelID
*>::const_iterator it
= channel_ids
->begin();
309 it
!= channel_ids
->end(); ++it
) {
310 DCHECK(channel_ids_
.find((*it
)->server_identifier()) ==
312 channel_ids_
[(*it
)->server_identifier()] = *it
;
314 channel_ids
->weak_clear();
318 base::TimeDelta wait_time
;
319 if (!waiting_tasks_
.empty())
320 wait_time
= base::TimeTicks::Now() - waiting_tasks_start_time_
;
321 DVLOG(1) << "Task delay " << wait_time
.InMilliseconds();
322 UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime",
324 base::TimeDelta::FromMilliseconds(1),
325 base::TimeDelta::FromMinutes(1),
327 UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount",
328 waiting_tasks_
.size());
331 for (ScopedVector
<Task
>::iterator i
= waiting_tasks_
.begin();
332 i
!= waiting_tasks_
.end(); ++i
)
334 waiting_tasks_
.clear();
337 void DefaultChannelIDStore::SyncSetChannelID(scoped_ptr
<ChannelID
> channel_id
) {
338 DCHECK(CalledOnValidThread());
341 InternalDeleteChannelID(channel_id
->server_identifier());
342 InternalInsertChannelID(channel_id
.Pass());
345 void DefaultChannelIDStore::SyncDeleteChannelID(
346 const std::string
& server_identifier
) {
347 DCHECK(CalledOnValidThread());
349 InternalDeleteChannelID(server_identifier
);
352 void DefaultChannelIDStore::SyncDeleteAllCreatedBetween(
353 base::Time delete_begin
,
354 base::Time delete_end
) {
355 DCHECK(CalledOnValidThread());
357 for (ChannelIDMap::iterator it
= channel_ids_
.begin();
358 it
!= channel_ids_
.end();) {
359 ChannelIDMap::iterator cur
= it
;
361 ChannelID
* channel_id
= cur
->second
;
362 if ((delete_begin
.is_null() ||
363 channel_id
->creation_time() >= delete_begin
) &&
364 (delete_end
.is_null() || channel_id
->creation_time() < delete_end
)) {
366 store_
->DeleteChannelID(*channel_id
);
368 channel_ids_
.erase(cur
);
373 void DefaultChannelIDStore::SyncGetAllChannelIDs(
374 ChannelIDList
* channel_id_list
) {
375 DCHECK(CalledOnValidThread());
377 for (ChannelIDMap::iterator it
= channel_ids_
.begin();
378 it
!= channel_ids_
.end(); ++it
)
379 channel_id_list
->push_back(*it
->second
);
382 void DefaultChannelIDStore::EnqueueTask(scoped_ptr
<Task
> task
) {
383 DCHECK(CalledOnValidThread());
385 if (waiting_tasks_
.empty())
386 waiting_tasks_start_time_
= base::TimeTicks::Now();
387 waiting_tasks_
.push_back(task
.Pass());
390 void DefaultChannelIDStore::RunOrEnqueueTask(scoped_ptr
<Task
> task
) {
391 DCHECK(CalledOnValidThread());
395 EnqueueTask(task
.Pass());
402 void DefaultChannelIDStore::InternalDeleteChannelID(
403 const std::string
& server_identifier
) {
404 DCHECK(CalledOnValidThread());
407 ChannelIDMap::iterator it
= channel_ids_
.find(server_identifier
);
408 if (it
== channel_ids_
.end())
409 return; // There is nothing to delete.
411 ChannelID
* channel_id
= it
->second
;
413 store_
->DeleteChannelID(*channel_id
);
414 channel_ids_
.erase(it
);
418 void DefaultChannelIDStore::InternalInsertChannelID(
419 scoped_ptr
<ChannelID
> channel_id
) {
420 DCHECK(CalledOnValidThread());
424 store_
->AddChannelID(*(channel_id
.get()));
425 const std::string
& server_identifier
= channel_id
->server_identifier();
426 channel_ids_
[server_identifier
] = channel_id
.release();
429 DefaultChannelIDStore::PersistentStore::PersistentStore() {}
431 DefaultChannelIDStore::PersistentStore::~PersistentStore() {}