1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/chromeos/certificate_provider/certificate_provider_service.h"
8 #include "base/bind_helpers.h"
9 #include "base/callback.h"
10 #include "base/location.h"
11 #include "base/logging.h"
12 #include "base/stl_util.h"
13 #include "base/strings/string_piece.h"
14 #include "base/task_runner.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "chrome/browser/chromeos/certificate_provider/certificate_provider.h"
17 #include "net/base/net_errors.h"
18 #include "net/ssl/client_key_store.h"
24 void PostSignResultToTaskRunner(
25 const scoped_refptr
<base::TaskRunner
>& target_task_runner
,
26 const net::SSLPrivateKey::SignCallback
& callback
,
28 const std::vector
<uint8_t>& signature
) {
29 target_task_runner
->PostTask(FROM_HERE
,
30 base::Bind(callback
, error
, signature
));
33 void PostCertificatesToTaskRunner(
34 const scoped_refptr
<base::TaskRunner
>& target_task_runner
,
35 const base::Callback
<void(const net::CertificateList
&)>& callback
,
36 const net::CertificateList
& certs
) {
37 target_task_runner
->PostTask(FROM_HERE
, base::Bind(callback
, certs
));
42 class CertificateProviderService::CertKeyProviderImpl
43 : public net::ClientKeyStore::CertKeyProvider
{
45 // |certificate_map| must outlive this provider. |service| must be
46 // dereferenceable on |service_task_runner|.
47 // This provider may be accessed from any thread. Methods and destructor must
48 // never be called concurrently.
50 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
51 const base::WeakPtr
<CertificateProviderService
>& service
,
52 certificate_provider::ThreadSafeCertificateMap
* certificate_map
);
53 ~CertKeyProviderImpl() override
;
55 bool GetCertificateKey(const net::X509Certificate
& cert
,
56 scoped_ptr
<net::SSLPrivateKey
>* private_key
) override
;
59 const scoped_refptr
<base::SequencedTaskRunner
> service_task_runner_
;
60 // Must be dereferenced on |service_task_runner_| only.
61 base::WeakPtr
<CertificateProviderService
> service_
;
62 certificate_provider::ThreadSafeCertificateMap
* const certificate_map_
;
64 DISALLOW_COPY_AND_ASSIGN(CertKeyProviderImpl
);
67 class CertificateProviderService::CertificateProviderImpl
68 : public CertificateProvider
{
70 // Any calls back to |service| will be posted to |service_task_runner|.
71 // |service| must be dereferenceable on |service_task_runner|.
72 // This provider is not thread safe, but can be used on any thread.
73 CertificateProviderImpl(
74 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
75 const base::WeakPtr
<CertificateProviderService
>& service
);
76 ~CertificateProviderImpl() override
;
78 void GetCertificates(const base::Callback
<void(const net::CertificateList
&)>&
81 scoped_ptr
<CertificateProvider
> Copy() override
;
84 static void GetCertificatesOnServiceThread(
85 const base::WeakPtr
<CertificateProviderService
>& service
,
86 const base::Callback
<void(const net::CertificateList
&)>& callback
);
88 const scoped_refptr
<base::SequencedTaskRunner
> service_task_runner_
;
89 // Must be dereferenced on |service_task_runner_| only.
90 const base::WeakPtr
<CertificateProviderService
> service_
;
92 DISALLOW_COPY_AND_ASSIGN(CertificateProviderImpl
);
95 // Implements an SSLPrivateKey backed by the signing function exposed by an
96 // extension through the certificateProvider API.
97 // Objects of this class must be used on a single thread. Any thread is allowed.
98 class CertificateProviderService::SSLPrivateKey
: public net::SSLPrivateKey
{
100 // Any calls back to |service| will be posted to |service_task_runner|.
101 // |service| must be dereferenceable on |service_task_runner|.
103 const std::string
& extension_id
,
104 const CertificateInfo
& cert_info
,
105 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
106 const base::WeakPtr
<CertificateProviderService
>& service
);
107 ~SSLPrivateKey() override
;
109 // net::SSLPrivateKey:
110 Type
GetType() override
;
111 bool SupportsHash(Hash hash
) override
;
112 size_t GetMaxSignatureLengthInBytes() override
;
113 void SignDigest(Hash hash
,
114 const base::StringPiece
& input
,
115 const SignCallback
& callback
) override
;
118 static void SignDigestOnServiceTaskRunner(
119 const base::WeakPtr
<CertificateProviderService
>& service
,
120 const std::string
& extension_id
,
121 const scoped_refptr
<net::X509Certificate
>& certificate
,
123 const std::string
& input
,
124 const SignCallback
& callback
);
126 void DidSignDigest(const SignCallback
& callback
,
128 const std::vector
<uint8_t>& signature
);
130 const std::string extension_id_
;
131 const CertificateInfo cert_info_
;
132 scoped_refptr
<base::SequencedTaskRunner
> service_task_runner_
;
133 // Must be dereferenced on |service_task_runner_| only.
134 const base::WeakPtr
<CertificateProviderService
> service_
;
135 base::ThreadChecker thread_checker_
;
136 base::WeakPtrFactory
<SSLPrivateKey
> weak_factory_
;
138 DISALLOW_COPY_AND_ASSIGN(SSLPrivateKey
);
141 CertificateProviderService::CertKeyProviderImpl::CertKeyProviderImpl(
142 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
143 const base::WeakPtr
<CertificateProviderService
>& service
,
144 certificate_provider::ThreadSafeCertificateMap
* certificate_map
)
145 : service_task_runner_(service_task_runner
),
147 certificate_map_(certificate_map
) {}
149 CertificateProviderService::CertKeyProviderImpl::~CertKeyProviderImpl() {}
151 bool CertificateProviderService::CertKeyProviderImpl::GetCertificateKey(
152 const net::X509Certificate
& cert
,
153 scoped_ptr
<net::SSLPrivateKey
>* private_key
) {
154 CertificateInfo info
;
155 std::string extension_id
;
156 if (!certificate_map_
->LookUpCertificate(cert
, &info
, &extension_id
))
160 new SSLPrivateKey(extension_id
, info
, service_task_runner_
, service_
));
164 CertificateProviderService::CertificateProviderImpl::CertificateProviderImpl(
165 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
166 const base::WeakPtr
<CertificateProviderService
>& service
)
167 : service_task_runner_(service_task_runner
), service_(service
) {}
169 CertificateProviderService::CertificateProviderImpl::
170 ~CertificateProviderImpl() {}
172 void CertificateProviderService::CertificateProviderImpl::GetCertificates(
173 const base::Callback
<void(const net::CertificateList
&)>& callback
) {
174 const scoped_refptr
<base::TaskRunner
> source_task_runner
=
175 base::ThreadTaskRunnerHandle::Get();
176 const base::Callback
<void(const net::CertificateList
&)>
177 callback_from_service_thread
= base::Bind(&PostCertificatesToTaskRunner
,
178 source_task_runner
, callback
);
180 service_task_runner_
->PostTask(
181 FROM_HERE
, base::Bind(&GetCertificatesOnServiceThread
, service_
,
182 callback_from_service_thread
));
185 scoped_ptr
<CertificateProvider
>
186 CertificateProviderService::CertificateProviderImpl::Copy() {
187 return make_scoped_ptr(
188 new CertificateProviderImpl(service_task_runner_
, service_
));
192 void CertificateProviderService::CertificateProviderImpl::
193 GetCertificatesOnServiceThread(
194 const base::WeakPtr
<CertificateProviderService
>& service
,
195 const base::Callback
<void(const net::CertificateList
&)>& callback
) {
197 callback
.Run(net::CertificateList());
200 service
->GetCertificatesFromExtensions(callback
);
203 CertificateProviderService::SSLPrivateKey::SSLPrivateKey(
204 const std::string
& extension_id
,
205 const CertificateInfo
& cert_info
,
206 const scoped_refptr
<base::SequencedTaskRunner
>& service_task_runner
,
207 const base::WeakPtr
<CertificateProviderService
>& service
)
208 : extension_id_(extension_id
),
209 cert_info_(cert_info
),
210 service_task_runner_(service_task_runner
),
212 weak_factory_(this) {
213 // This constructor is called on |service_task_runner|. Only subsequent calls
214 // to member functions have to be on a common thread.
215 thread_checker_
.DetachFromThread();
218 CertificateProviderService::SSLPrivateKey::~SSLPrivateKey() {
219 DCHECK(thread_checker_
.CalledOnValidThread());
222 CertificateProviderService::SSLPrivateKey::Type
223 CertificateProviderService::SSLPrivateKey::GetType() {
224 DCHECK(thread_checker_
.CalledOnValidThread());
225 return cert_info_
.type
;
228 bool CertificateProviderService::SSLPrivateKey::SupportsHash(Hash hash
) {
229 DCHECK(thread_checker_
.CalledOnValidThread());
230 return ContainsValue(cert_info_
.supported_hashes
, hash
);
234 CertificateProviderService::SSLPrivateKey::GetMaxSignatureLengthInBytes() {
235 DCHECK(thread_checker_
.CalledOnValidThread());
236 return cert_info_
.max_signature_length_in_bytes
;
240 void CertificateProviderService::SSLPrivateKey::SignDigestOnServiceTaskRunner(
241 const base::WeakPtr
<CertificateProviderService
>& service
,
242 const std::string
& extension_id
,
243 const scoped_refptr
<net::X509Certificate
>& certificate
,
245 const std::string
& input
,
246 const SignCallback
& callback
) {
248 const std::vector
<uint8_t> no_signature
;
249 callback
.Run(net::ERR_FAILED
, no_signature
);
252 service
->RequestSignatureFromExtension(extension_id
, certificate
, hash
, input
,
256 void CertificateProviderService::SSLPrivateKey::SignDigest(
258 const base::StringPiece
& input
,
259 const SignCallback
& callback
) {
260 DCHECK(thread_checker_
.CalledOnValidThread());
261 const scoped_refptr
<base::TaskRunner
> source_task_runner
=
262 base::ThreadTaskRunnerHandle::Get();
263 const SignCallback bound_callback
=
264 // The CertificateProviderService calls back on another thread, so post
265 // back to the current thread.
266 base::Bind(&PostSignResultToTaskRunner
, source_task_runner
,
267 // Drop the result and don't call back if this key handle is
268 // destroyed in the meantime.
269 base::Bind(&SSLPrivateKey::DidSignDigest
,
270 weak_factory_
.GetWeakPtr(), callback
));
272 service_task_runner_
->PostTask(
273 FROM_HERE
, base::Bind(&SSLPrivateKey::SignDigestOnServiceTaskRunner
,
274 service_
, extension_id_
, cert_info_
.certificate
,
275 hash
, input
.as_string(), bound_callback
));
278 void CertificateProviderService::SSLPrivateKey::DidSignDigest(
279 const SignCallback
& callback
,
281 const std::vector
<uint8_t>& signature
) {
282 DCHECK(thread_checker_
.CalledOnValidThread());
283 callback
.Run(error
, signature
);
286 CertificateProviderService::CertificateProviderService()
287 : weak_factory_(this) {}
289 CertificateProviderService::~CertificateProviderService() {
290 DCHECK(thread_checker_
.CalledOnValidThread());
292 // ClientKeyStore serializes access to |cert_key_provider_|.
293 // Once RemoveProvider() returns, it is guaranteed that there are no more
294 // accesses to |cert_key_provider_| in flight and no references to
295 // |cert_key_provider_| are remaining. This service will hold the last
296 // reference to |cert_key_provider_|.
297 net::ClientKeyStore::GetInstance()->RemoveProvider(cert_key_provider_
.get());
298 cert_key_provider_
.reset();
301 void CertificateProviderService::SetDelegate(scoped_ptr
<Delegate
> delegate
) {
302 DCHECK(thread_checker_
.CalledOnValidThread());
306 delegate_
= delegate
.Pass();
307 cert_key_provider_
.reset(
308 new CertKeyProviderImpl(base::ThreadTaskRunnerHandle::Get(),
309 weak_factory_
.GetWeakPtr(), &certificate_map_
));
310 net::ClientKeyStore::GetInstance()->AddProvider(cert_key_provider_
.get());
313 bool CertificateProviderService::SetCertificatesProvidedByExtension(
314 const std::string
& extension_id
,
316 const CertificateInfoList
& certificate_infos
) {
317 DCHECK(thread_checker_
.CalledOnValidThread());
319 bool completed
= false;
320 if (!certificate_requests_
.SetCertificates(extension_id
, cert_request_id
,
321 certificate_infos
, &completed
)) {
322 DLOG(WARNING
) << "Unexpected reply of extension " << extension_id
323 << " to request " << cert_request_id
;
327 std::map
<std::string
, CertificateInfoList
> certificates
;
328 base::Callback
<void(const net::CertificateList
&)> callback
;
329 certificate_requests_
.RemoveRequest(cert_request_id
, &certificates
,
331 UpdateCertificatesAndRun(certificates
, callback
);
336 void CertificateProviderService::ReplyToSignRequest(
337 const std::string
& extension_id
,
339 const std::vector
<uint8_t>& signature
) {
340 DCHECK(thread_checker_
.CalledOnValidThread());
342 net::SSLPrivateKey::SignCallback callback
;
343 if (!sign_requests_
.RemoveRequest(extension_id
, sign_request_id
, &callback
)) {
344 LOG(ERROR
) << "request id unknown.";
345 // Maybe multiple replies to the same request.
349 const net::Error error_code
= signature
.empty() ? net::ERR_FAILED
: net::OK
;
350 callback
.Run(error_code
, signature
);
353 scoped_ptr
<CertificateProvider
>
354 CertificateProviderService::CreateCertificateProvider() {
355 DCHECK(thread_checker_
.CalledOnValidThread());
357 return make_scoped_ptr(new CertificateProviderImpl(
358 base::ThreadTaskRunnerHandle::Get(), weak_factory_
.GetWeakPtr()));
361 void CertificateProviderService::OnExtensionUnloaded(
362 const std::string
& extension_id
) {
363 DCHECK(thread_checker_
.CalledOnValidThread());
365 for (const int cert_request_id
:
366 certificate_requests_
.DropExtension(extension_id
)) {
367 std::map
<std::string
, CertificateInfoList
> certificates
;
368 base::Callback
<void(const net::CertificateList
&)> callback
;
369 certificate_requests_
.RemoveRequest(cert_request_id
, &certificates
,
371 UpdateCertificatesAndRun(certificates
, callback
);
374 certificate_map_
.RemoveExtension(extension_id
);
376 for (auto callback
: sign_requests_
.RemoveAllRequests(extension_id
))
377 callback
.Run(net::ERR_FAILED
, std::vector
<uint8_t>());
380 void CertificateProviderService::GetCertificatesFromExtensions(
381 const base::Callback
<void(const net::CertificateList
&)>& callback
) {
382 DCHECK(thread_checker_
.CalledOnValidThread());
384 const std::vector
<std::string
> provider_extensions(
385 delegate_
->CertificateProviderExtensions());
387 if (provider_extensions
.empty()) {
388 DVLOG(2) << "No provider extensions left, clear all certificates.";
389 UpdateCertificatesAndRun(std::map
<std::string
, CertificateInfoList
>(),
394 const int cert_request_id
= certificate_requests_
.AddRequest(
395 provider_extensions
, callback
,
396 base::Bind(&CertificateProviderService::TerminateCertificateRequest
,
397 base::Unretained(this)));
399 DVLOG(2) << "Start certificate request " << cert_request_id
;
400 delegate_
->BroadcastCertificateRequest(cert_request_id
);
403 void CertificateProviderService::UpdateCertificatesAndRun(
404 const std::map
<std::string
, CertificateInfoList
>& extension_to_certificates
,
405 const base::Callback
<void(const net::CertificateList
&)>& callback
) {
406 DCHECK(thread_checker_
.CalledOnValidThread());
408 // Extensions are removed from the service's state when they're unloaded.
409 // Any remaining extension is assumed to be enabled.
410 certificate_map_
.Update(extension_to_certificates
);
412 net::CertificateList all_certs
;
413 for (const auto& entry
: extension_to_certificates
) {
414 for (const CertificateInfo
& cert_info
: entry
.second
)
415 all_certs
.push_back(cert_info
.certificate
);
418 callback
.Run(all_certs
);
421 void CertificateProviderService::TerminateCertificateRequest(
422 int cert_request_id
) {
423 DCHECK(thread_checker_
.CalledOnValidThread());
425 std::map
<std::string
, CertificateInfoList
> certificates
;
426 base::Callback
<void(const net::CertificateList
&)> callback
;
427 if (!certificate_requests_
.RemoveRequest(cert_request_id
, &certificates
,
429 DLOG(WARNING
) << "Request id " << cert_request_id
<< " unknown.";
433 DVLOG(1) << "Time out certificate request " << cert_request_id
;
434 UpdateCertificatesAndRun(certificates
, callback
);
437 void CertificateProviderService::RequestSignatureFromExtension(
438 const std::string
& extension_id
,
439 const scoped_refptr
<net::X509Certificate
>& certificate
,
440 net::SSLPrivateKey::Hash hash
,
441 const std::string
& digest
,
442 const net::SSLPrivateKey::SignCallback
& callback
) {
443 DCHECK(thread_checker_
.CalledOnValidThread());
445 const int sign_request_id
= sign_requests_
.AddRequest(extension_id
, callback
);
446 if (!delegate_
->DispatchSignRequestToExtension(extension_id
, sign_request_id
,
447 hash
, certificate
, digest
)) {
448 sign_requests_
.RemoveRequest(extension_id
, sign_request_id
,
449 nullptr /* callback */);
450 callback
.Run(net::ERR_FAILED
, std::vector
<uint8_t>());
454 } // namespace chromeos