// Copyright 2015 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/chromeos/certificate_provider/certificate_provider_service.h"

#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/callback.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/stl_util.h"
#include "base/strings/string_piece.h"
#include "base/task_runner.h"
#include "base/thread_task_runner_handle.h"
#include "chrome/browser/chromeos/certificate_provider/certificate_provider.h"
#include "net/base/net_errors.h"
#include "net/ssl/client_key_store.h"
24 void PostSignResultToTaskRunner(
25 const scoped_refptr
<base::TaskRunner
>& target_task_runner
,
26 const net::SSLPrivateKey::SignCallback
& callback
,
28 const std::vector
<uint8_t>& signature
) {
29 target_task_runner
->PostTask(FROM_HERE
,
30 base::Bind(callback
, error
, signature
));
33 void PostCertificatesToTaskRunner(
34 const scoped_refptr
<base::TaskRunner
>& target_task_runner
,
35 const base::Callback
<void(const net::CertificateList
&)>& callback
,
36 const net::CertificateList
& certs
) {
37 target_task_runner
->PostTask(FROM_HERE
, base::Bind(callback
, certs
));
42 class CertificateProviderService::CertKeyProviderImpl
43 : public net::ClientKeyStore::CertKeyProvider
{
45 // |certificate_map| must outlive this provider. |service| must be
46 // dereferenceable on |service_task_runner|.
47 // This provider may be accessed from any thread. Methods and destructor must
48 // never be called concurrently.
50 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
51 const base::WeakPtr
<CertificateProviderService
>& service
,
52 certificate_provider::ThreadSafeCertificateMap
* certificate_map
);
53 ~CertKeyProviderImpl() override
;
55 bool GetCertificateKey(const net::X509Certificate
& cert
,
56 scoped_ptr
<net::SSLPrivateKey
>* private_key
) override
;
59 const scoped_refptr
<base::SequencedTaskRunner
> service_task_runner_
;
60 // Must be dereferenced on |service_task_runner_| only.
61 base::WeakPtr
<CertificateProviderService
> service_
;
62 certificate_provider::ThreadSafeCertificateMap
* const certificate_map_
;
64 DISALLOW_COPY_AND_ASSIGN(CertKeyProviderImpl
);
67 class CertificateProviderService::CertificateProviderImpl
68 : public CertificateProvider
{
70 // Any calls back to |service| will be posted to |service_task_runner|.
71 // |service| must be dereferenceable on |service_task_runner|.
72 // This provider is not thread safe, but can be used on any thread.
73 CertificateProviderImpl(
74 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
75 const base::WeakPtr
<CertificateProviderService
>& service
);
76 ~CertificateProviderImpl() override
;
78 void GetCertificates(const base::Callback
<void(const net::CertificateList
&)>&
81 scoped_ptr
<CertificateProvider
> Copy() override
;
84 static void GetCertificatesOnServiceThread(
85 const base::WeakPtr
<CertificateProviderService
>& service
,
86 const base::Callback
<void(const net::CertificateList
&)>& callback
);
88 const scoped_refptr
<base::SequencedTaskRunner
> service_task_runner_
;
89 // Must be dereferenced on |service_task_runner_| only.
90 const base::WeakPtr
<CertificateProviderService
> service_
;
92 DISALLOW_COPY_AND_ASSIGN(CertificateProviderImpl
);
95 // Implements an SSLPrivateKey backed by the signing function exposed by an
96 // extension through the certificateProvider API.
97 // Objects of this class must be used on a single thread. Any thread is allowed.
98 class CertificateProviderService::SSLPrivateKey
: public net::SSLPrivateKey
{
100 // Any calls back to |service| will be posted to |service_task_runner|.
101 // |service| must be dereferenceable on |service_task_runner|.
103 const std::string
& extension_id
,
104 const CertificateInfo
& cert_info
,
105 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
106 const base::WeakPtr
<CertificateProviderService
>& service
);
107 ~SSLPrivateKey() override
;
109 // net::SSLPrivateKey:
110 Type
GetType() override
;
111 bool SupportsHash(Hash hash
) override
;
112 size_t GetMaxSignatureLengthInBytes() override
;
113 void SignDigest(Hash hash
,
114 const base::StringPiece
& input
,
115 const SignCallback
& callback
) override
;
118 static void SignDigestOnServiceTaskRunner(
119 const base::WeakPtr
<CertificateProviderService
>& service
,
120 const std::string
& extension_id
,
121 const scoped_refptr
<net::X509Certificate
>& certificate
,
123 const std::string
& input
,
124 const SignCallback
& callback
);
126 void DidSignDigest(const SignCallback
& callback
,
128 const std::vector
<uint8_t>& signature
);
130 const std::string extension_id_
;
131 const CertificateInfo cert_info_
;
132 scoped_refptr
<base::SequencedTaskRunner
> service_task_runner_
;
133 // Must be dereferenced on |service_task_runner_| only.
134 const base::WeakPtr
<CertificateProviderService
> service_
;
135 base::ThreadChecker thread_checker_
;
136 base::WeakPtrFactory
<SSLPrivateKey
> weak_factory_
;
138 DISALLOW_COPY_AND_ASSIGN(SSLPrivateKey
);
141 CertificateProviderService::CertKeyProviderImpl::CertKeyProviderImpl(
142 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
143 const base::WeakPtr
<CertificateProviderService
>& service
,
144 certificate_provider::ThreadSafeCertificateMap
* certificate_map
)
145 : service_task_runner_(service_task_runner
),
147 certificate_map_(certificate_map
) {}
149 CertificateProviderService::CertKeyProviderImpl::~CertKeyProviderImpl() {}
151 bool CertificateProviderService::CertKeyProviderImpl::GetCertificateKey(
152 const net::X509Certificate
& cert
,
153 scoped_ptr
<net::SSLPrivateKey
>* private_key
) {
154 bool is_currently_provided
= false;
155 CertificateInfo info
;
156 std::string extension_id
;
157 certificate_map_
->LookUpCertificate(cert
, &is_currently_provided
, &info
,
159 if (!is_currently_provided
)
163 new SSLPrivateKey(extension_id
, info
, service_task_runner_
, service_
));
167 CertificateProviderService::CertificateProviderImpl::CertificateProviderImpl(
168 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
169 const base::WeakPtr
<CertificateProviderService
>& service
)
170 : service_task_runner_(service_task_runner
), service_(service
) {}
172 CertificateProviderService::CertificateProviderImpl::
173 ~CertificateProviderImpl() {}
175 void CertificateProviderService::CertificateProviderImpl::GetCertificates(
176 const base::Callback
<void(const net::CertificateList
&)>& callback
) {
177 const scoped_refptr
<base::TaskRunner
> source_task_runner
=
178 base::ThreadTaskRunnerHandle::Get();
179 const base::Callback
<void(const net::CertificateList
&)>
180 callback_from_service_thread
= base::Bind(&PostCertificatesToTaskRunner
,
181 source_task_runner
, callback
);
183 service_task_runner_
->PostTask(
184 FROM_HERE
, base::Bind(&GetCertificatesOnServiceThread
, service_
,
185 callback_from_service_thread
));
188 scoped_ptr
<CertificateProvider
>
189 CertificateProviderService::CertificateProviderImpl::Copy() {
190 return make_scoped_ptr(
191 new CertificateProviderImpl(service_task_runner_
, service_
));
195 void CertificateProviderService::CertificateProviderImpl::
196 GetCertificatesOnServiceThread(
197 const base::WeakPtr
<CertificateProviderService
>& service
,
198 const base::Callback
<void(const net::CertificateList
&)>& callback
) {
200 callback
.Run(net::CertificateList());
203 service
->GetCertificatesFromExtensions(callback
);
206 CertificateProviderService::SSLPrivateKey::SSLPrivateKey(
207 const std::string
& extension_id
,
208 const CertificateInfo
& cert_info
,
209 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
210 const base::WeakPtr
<CertificateProviderService
>& service
)
211 : extension_id_(extension_id
),
212 cert_info_(cert_info
),
213 service_task_runner_(service_task_runner
),
215 weak_factory_(this) {
216 // This constructor is called on |service_task_runner|. Only subsequent calls
217 // to member functions have to be on a common thread.
218 thread_checker_
.DetachFromThread();
221 CertificateProviderService::SSLPrivateKey::~SSLPrivateKey() {
222 DCHECK(thread_checker_
.CalledOnValidThread());
225 CertificateProviderService::SSLPrivateKey::Type
226 CertificateProviderService::SSLPrivateKey::GetType() {
227 DCHECK(thread_checker_
.CalledOnValidThread());
228 return cert_info_
.type
;
231 bool CertificateProviderService::SSLPrivateKey::SupportsHash(Hash hash
) {
232 DCHECK(thread_checker_
.CalledOnValidThread());
233 return ContainsValue(cert_info_
.supported_hashes
, hash
);
237 CertificateProviderService::SSLPrivateKey::GetMaxSignatureLengthInBytes() {
238 DCHECK(thread_checker_
.CalledOnValidThread());
239 return cert_info_
.max_signature_length_in_bytes
;
243 void CertificateProviderService::SSLPrivateKey::SignDigestOnServiceTaskRunner(
244 const base::WeakPtr
<CertificateProviderService
>& service
,
245 const std::string
& extension_id
,
246 const scoped_refptr
<net::X509Certificate
>& certificate
,
248 const std::string
& input
,
249 const SignCallback
& callback
) {
251 const std::vector
<uint8_t> no_signature
;
252 callback
.Run(net::ERR_FAILED
, no_signature
);
255 service
->RequestSignatureFromExtension(extension_id
, certificate
, hash
, input
,
259 void CertificateProviderService::SSLPrivateKey::SignDigest(
261 const base::StringPiece
& input
,
262 const SignCallback
& callback
) {
263 DCHECK(thread_checker_
.CalledOnValidThread());
264 const scoped_refptr
<base::TaskRunner
> source_task_runner
=
265 base::ThreadTaskRunnerHandle::Get();
266 const SignCallback bound_callback
=
267 // The CertificateProviderService calls back on another thread, so post
268 // back to the current thread.
269 base::Bind(&PostSignResultToTaskRunner
, source_task_runner
,
270 // Drop the result and don't call back if this key handle is
271 // destroyed in the meantime.
272 base::Bind(&SSLPrivateKey::DidSignDigest
,
273 weak_factory_
.GetWeakPtr(), callback
));
275 service_task_runner_
->PostTask(
276 FROM_HERE
, base::Bind(&SSLPrivateKey::SignDigestOnServiceTaskRunner
,
277 service_
, extension_id_
, cert_info_
.certificate
,
278 hash
, input
.as_string(), bound_callback
));
281 void CertificateProviderService::SSLPrivateKey::DidSignDigest(
282 const SignCallback
& callback
,
284 const std::vector
<uint8_t>& signature
) {
285 DCHECK(thread_checker_
.CalledOnValidThread());
286 callback
.Run(error
, signature
);
289 CertificateProviderService::CertificateProviderService()
290 : weak_factory_(this) {}
292 CertificateProviderService::~CertificateProviderService() {
293 DCHECK(thread_checker_
.CalledOnValidThread());
295 // ClientKeyStore serializes access to |cert_key_provider_|.
296 // Once RemoveProvider() returns, it is guaranteed that there are no more
297 // accesses to |cert_key_provider_| in flight and no references to
298 // |cert_key_provider_| are remaining. This service will hold the last
299 // reference to |cert_key_provider_|.
300 net::ClientKeyStore::GetInstance()->RemoveProvider(cert_key_provider_
.get());
301 cert_key_provider_
.reset();
304 void CertificateProviderService::SetDelegate(scoped_ptr
<Delegate
> delegate
) {
305 DCHECK(thread_checker_
.CalledOnValidThread());
309 delegate_
= delegate
.Pass();
310 cert_key_provider_
.reset(
311 new CertKeyProviderImpl(base::ThreadTaskRunnerHandle::Get(),
312 weak_factory_
.GetWeakPtr(), &certificate_map_
));
313 net::ClientKeyStore::GetInstance()->AddProvider(cert_key_provider_
.get());
316 bool CertificateProviderService::SetCertificatesProvidedByExtension(
317 const std::string
& extension_id
,
319 const CertificateInfoList
& certificate_infos
) {
320 DCHECK(thread_checker_
.CalledOnValidThread());
322 bool completed
= false;
323 if (!certificate_requests_
.SetCertificates(extension_id
, cert_request_id
,
324 certificate_infos
, &completed
)) {
325 DLOG(WARNING
) << "Unexpected reply of extension " << extension_id
326 << " to request " << cert_request_id
;
330 std::map
<std::string
, CertificateInfoList
> certificates
;
331 base::Callback
<void(const net::CertificateList
&)> callback
;
332 certificate_requests_
.RemoveRequest(cert_request_id
, &certificates
,
334 UpdateCertificatesAndRun(certificates
, callback
);
339 void CertificateProviderService::ReplyToSignRequest(
340 const std::string
& extension_id
,
342 const std::vector
<uint8_t>& signature
) {
343 DCHECK(thread_checker_
.CalledOnValidThread());
345 net::SSLPrivateKey::SignCallback callback
;
346 if (!sign_requests_
.RemoveRequest(extension_id
, sign_request_id
, &callback
)) {
347 LOG(ERROR
) << "request id unknown.";
348 // Maybe multiple replies to the same request.
352 const net::Error error_code
= signature
.empty() ? net::ERR_FAILED
: net::OK
;
353 callback
.Run(error_code
, signature
);
356 bool CertificateProviderService::LookUpCertificate(
357 const net::X509Certificate
& cert
,
359 std::string
* extension_id
) {
360 DCHECK(thread_checker_
.CalledOnValidThread());
362 CertificateInfo unused_info
;
363 return certificate_map_
.LookUpCertificate(cert
, has_extension
, &unused_info
,
367 scoped_ptr
<CertificateProvider
>
368 CertificateProviderService::CreateCertificateProvider() {
369 DCHECK(thread_checker_
.CalledOnValidThread());
371 return make_scoped_ptr(new CertificateProviderImpl(
372 base::ThreadTaskRunnerHandle::Get(), weak_factory_
.GetWeakPtr()));
375 void CertificateProviderService::OnExtensionUnloaded(
376 const std::string
& extension_id
) {
377 DCHECK(thread_checker_
.CalledOnValidThread());
379 for (const int cert_request_id
:
380 certificate_requests_
.DropExtension(extension_id
)) {
381 std::map
<std::string
, CertificateInfoList
> certificates
;
382 base::Callback
<void(const net::CertificateList
&)> callback
;
383 certificate_requests_
.RemoveRequest(cert_request_id
, &certificates
,
385 UpdateCertificatesAndRun(certificates
, callback
);
388 certificate_map_
.RemoveExtension(extension_id
);
390 for (auto callback
: sign_requests_
.RemoveAllRequests(extension_id
))
391 callback
.Run(net::ERR_FAILED
, std::vector
<uint8_t>());
394 void CertificateProviderService::GetCertificatesFromExtensions(
395 const base::Callback
<void(const net::CertificateList
&)>& callback
) {
396 DCHECK(thread_checker_
.CalledOnValidThread());
398 const std::vector
<std::string
> provider_extensions(
399 delegate_
->CertificateProviderExtensions());
401 if (provider_extensions
.empty()) {
402 DVLOG(2) << "No provider extensions left, clear all certificates.";
403 UpdateCertificatesAndRun(std::map
<std::string
, CertificateInfoList
>(),
408 const int cert_request_id
= certificate_requests_
.AddRequest(
409 provider_extensions
, callback
,
410 base::Bind(&CertificateProviderService::TerminateCertificateRequest
,
411 base::Unretained(this)));
413 DVLOG(2) << "Start certificate request " << cert_request_id
;
414 delegate_
->BroadcastCertificateRequest(cert_request_id
);
417 void CertificateProviderService::UpdateCertificatesAndRun(
418 const std::map
<std::string
, CertificateInfoList
>& extension_to_certificates
,
419 const base::Callback
<void(const net::CertificateList
&)>& callback
) {
420 DCHECK(thread_checker_
.CalledOnValidThread());
422 // Extensions are removed from the service's state when they're unloaded.
423 // Any remaining extension is assumed to be enabled.
424 certificate_map_
.Update(extension_to_certificates
);
426 net::CertificateList all_certs
;
427 for (const auto& entry
: extension_to_certificates
) {
428 for (const CertificateInfo
& cert_info
: entry
.second
)
429 all_certs
.push_back(cert_info
.certificate
);
432 callback
.Run(all_certs
);
435 void CertificateProviderService::TerminateCertificateRequest(
436 int cert_request_id
) {
437 DCHECK(thread_checker_
.CalledOnValidThread());
439 std::map
<std::string
, CertificateInfoList
> certificates
;
440 base::Callback
<void(const net::CertificateList
&)> callback
;
441 if (!certificate_requests_
.RemoveRequest(cert_request_id
, &certificates
,
443 DLOG(WARNING
) << "Request id " << cert_request_id
<< " unknown.";
447 DVLOG(1) << "Time out certificate request " << cert_request_id
;
448 UpdateCertificatesAndRun(certificates
, callback
);
451 void CertificateProviderService::RequestSignatureFromExtension(
452 const std::string
& extension_id
,
453 const scoped_refptr
<net::X509Certificate
>& certificate
,
454 net::SSLPrivateKey::Hash hash
,
455 const std::string
& digest
,
456 const net::SSLPrivateKey::SignCallback
& callback
) {
457 DCHECK(thread_checker_
.CalledOnValidThread());
459 const int sign_request_id
= sign_requests_
.AddRequest(extension_id
, callback
);
460 if (!delegate_
->DispatchSignRequestToExtension(extension_id
, sign_request_id
,
461 hash
, certificate
, digest
)) {
462 sign_requests_
.RemoveRequest(extension_id
, sign_request_id
,
463 nullptr /* callback */);
464 callback
.Run(net::ERR_FAILED
, std::vector
<uint8_t>());
468 } // namespace chromeos