Add new certificateProvider extension API.
[chromium-blink-merge.git] / chrome / browser / chromeos / certificate_provider / certificate_provider_service.cc
blobda19c82e8540c127f5d652fe4c32a1752626ff2e
1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/chromeos/certificate_provider/certificate_provider_service.h"
7 #include "base/bind.h"
8 #include "base/bind_helpers.h"
9 #include "base/callback.h"
10 #include "base/location.h"
11 #include "base/logging.h"
12 #include "base/stl_util.h"
13 #include "base/strings/string_piece.h"
14 #include "base/task_runner.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "chrome/browser/chromeos/certificate_provider/certificate_provider.h"
17 #include "net/base/net_errors.h"
18 #include "net/ssl/client_key_store.h"
20 namespace chromeos {
22 namespace {
24 void PostSignResultToTaskRunner(
25 const scoped_refptr<base::TaskRunner>& target_task_runner,
26 const net::SSLPrivateKey::SignCallback& callback,
27 net::Error error,
28 const std::vector<uint8_t>& signature) {
29 target_task_runner->PostTask(FROM_HERE,
30 base::Bind(callback, error, signature));
33 void PostCertificatesToTaskRunner(
34 const scoped_refptr<base::TaskRunner>& target_task_runner,
35 const base::Callback<void(const net::CertificateList&)>& callback,
36 const net::CertificateList& certs) {
37 target_task_runner->PostTask(FROM_HERE, base::Bind(callback, certs));
40 } // namespace
42 class CertificateProviderService::CertKeyProviderImpl
43 : public net::ClientKeyStore::CertKeyProvider {
44 public:
45 // |certificate_map| must outlive this provider. |service| must be
46 // dereferenceable on |service_task_runner|.
47 // This provider may be accessed from any thread. Methods and destructor must
48 // never be called concurrently.
49 CertKeyProviderImpl(
50 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
51 const base::WeakPtr<CertificateProviderService>& service,
52 certificate_provider::ThreadSafeCertificateMap* certificate_map);
53 ~CertKeyProviderImpl() override;
55 bool GetCertificateKey(const net::X509Certificate& cert,
56 scoped_ptr<net::SSLPrivateKey>* private_key) override;
58 private:
59 const scoped_refptr<base::SequencedTaskRunner> service_task_runner_;
60 // Must be dereferenced on |service_task_runner_| only.
61 base::WeakPtr<CertificateProviderService> service_;
62 certificate_provider::ThreadSafeCertificateMap* const certificate_map_;
64 DISALLOW_COPY_AND_ASSIGN(CertKeyProviderImpl);
67 class CertificateProviderService::CertificateProviderImpl
68 : public CertificateProvider {
69 public:
70 // Any calls back to |service| will be posted to |service_task_runner|.
71 // |service| must be dereferenceable on |service_task_runner|.
72 // This provider is not thread safe, but can be used on any thread.
73 CertificateProviderImpl(
74 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
75 const base::WeakPtr<CertificateProviderService>& service);
76 ~CertificateProviderImpl() override;
78 void GetCertificates(const base::Callback<void(const net::CertificateList&)>&
79 callback) override;
81 scoped_ptr<CertificateProvider> Copy() override;
83 private:
84 static void GetCertificatesOnServiceThread(
85 const base::WeakPtr<CertificateProviderService>& service,
86 const base::Callback<void(const net::CertificateList&)>& callback);
88 const scoped_refptr<base::SequencedTaskRunner> service_task_runner_;
89 // Must be dereferenced on |service_task_runner_| only.
90 const base::WeakPtr<CertificateProviderService> service_;
92 DISALLOW_COPY_AND_ASSIGN(CertificateProviderImpl);
95 // Implements an SSLPrivateKey backed by the signing function exposed by an
96 // extension through the certificateProvider API.
97 // Objects of this class must be used on a single thread. Any thread is allowed.
98 class CertificateProviderService::SSLPrivateKey : public net::SSLPrivateKey {
99 public:
100 // Any calls back to |service| will be posted to |service_task_runner|.
101 // |service| must be dereferenceable on |service_task_runner|.
102 SSLPrivateKey(
103 const std::string& extension_id,
104 const CertificateInfo& cert_info,
105 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
106 const base::WeakPtr<CertificateProviderService>& service);
107 ~SSLPrivateKey() override;
109 // net::SSLPrivateKey:
110 Type GetType() override;
111 bool SupportsHash(Hash hash) override;
112 size_t GetMaxSignatureLengthInBytes() override;
113 void SignDigest(Hash hash,
114 const base::StringPiece& input,
115 const SignCallback& callback) override;
117 private:
118 static void SignDigestOnServiceTaskRunner(
119 const base::WeakPtr<CertificateProviderService>& service,
120 const std::string& extension_id,
121 const scoped_refptr<net::X509Certificate>& certificate,
122 Hash hash,
123 const std::string& input,
124 const SignCallback& callback);
126 void DidSignDigest(const SignCallback& callback,
127 net::Error error,
128 const std::vector<uint8_t>& signature);
130 const std::string extension_id_;
131 const CertificateInfo cert_info_;
132 scoped_refptr<base::SequencedTaskRunner> service_task_runner_;
133 // Must be dereferenced on |service_task_runner_| only.
134 const base::WeakPtr<CertificateProviderService> service_;
135 base::ThreadChecker thread_checker_;
136 base::WeakPtrFactory<SSLPrivateKey> weak_factory_;
138 DISALLOW_COPY_AND_ASSIGN(SSLPrivateKey);
141 CertificateProviderService::CertKeyProviderImpl::CertKeyProviderImpl(
142 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
143 const base::WeakPtr<CertificateProviderService>& service,
144 certificate_provider::ThreadSafeCertificateMap* certificate_map)
145 : service_task_runner_(service_task_runner),
146 service_(service),
147 certificate_map_(certificate_map) {}
149 CertificateProviderService::CertKeyProviderImpl::~CertKeyProviderImpl() {}
151 bool CertificateProviderService::CertKeyProviderImpl::GetCertificateKey(
152 const net::X509Certificate& cert,
153 scoped_ptr<net::SSLPrivateKey>* private_key) {
154 CertificateInfo info;
155 std::string extension_id;
156 if (!certificate_map_->LookUpCertificate(cert, &info, &extension_id))
157 return false;
159 private_key->reset(
160 new SSLPrivateKey(extension_id, info, service_task_runner_, service_));
161 return true;
164 CertificateProviderService::CertificateProviderImpl::CertificateProviderImpl(
165 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
166 const base::WeakPtr<CertificateProviderService>& service)
167 : service_task_runner_(service_task_runner), service_(service) {}
169 CertificateProviderService::CertificateProviderImpl::
170 ~CertificateProviderImpl() {}
172 void CertificateProviderService::CertificateProviderImpl::GetCertificates(
173 const base::Callback<void(const net::CertificateList&)>& callback) {
174 const scoped_refptr<base::TaskRunner> source_task_runner =
175 base::ThreadTaskRunnerHandle::Get();
176 const base::Callback<void(const net::CertificateList&)>
177 callback_from_service_thread = base::Bind(&PostCertificatesToTaskRunner,
178 source_task_runner, callback);
180 service_task_runner_->PostTask(
181 FROM_HERE, base::Bind(&GetCertificatesOnServiceThread, service_,
182 callback_from_service_thread));
185 scoped_ptr<CertificateProvider>
186 CertificateProviderService::CertificateProviderImpl::Copy() {
187 return make_scoped_ptr(
188 new CertificateProviderImpl(service_task_runner_, service_));
191 // static
192 void CertificateProviderService::CertificateProviderImpl::
193 GetCertificatesOnServiceThread(
194 const base::WeakPtr<CertificateProviderService>& service,
195 const base::Callback<void(const net::CertificateList&)>& callback) {
196 if (!service) {
197 callback.Run(net::CertificateList());
198 return;
200 service->GetCertificatesFromExtensions(callback);
203 CertificateProviderService::SSLPrivateKey::SSLPrivateKey(
204 const std::string& extension_id,
205 const CertificateInfo& cert_info,
206 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
207 const base::WeakPtr<CertificateProviderService>& service)
208 : extension_id_(extension_id),
209 cert_info_(cert_info),
210 service_task_runner_(service_task_runner),
211 service_(service),
212 weak_factory_(this) {
213 // This constructor is called on |service_task_runner|. Only subsequent calls
214 // to member functions have to be on a common thread.
215 thread_checker_.DetachFromThread();
218 CertificateProviderService::SSLPrivateKey::~SSLPrivateKey() {
219 DCHECK(thread_checker_.CalledOnValidThread());
222 CertificateProviderService::SSLPrivateKey::Type
223 CertificateProviderService::SSLPrivateKey::GetType() {
224 DCHECK(thread_checker_.CalledOnValidThread());
225 return cert_info_.type;
228 bool CertificateProviderService::SSLPrivateKey::SupportsHash(Hash hash) {
229 DCHECK(thread_checker_.CalledOnValidThread());
230 return ContainsValue(cert_info_.supported_hashes, hash);
233 size_t
234 CertificateProviderService::SSLPrivateKey::GetMaxSignatureLengthInBytes() {
235 DCHECK(thread_checker_.CalledOnValidThread());
236 return cert_info_.max_signature_length_in_bytes;
239 // static
240 void CertificateProviderService::SSLPrivateKey::SignDigestOnServiceTaskRunner(
241 const base::WeakPtr<CertificateProviderService>& service,
242 const std::string& extension_id,
243 const scoped_refptr<net::X509Certificate>& certificate,
244 Hash hash,
245 const std::string& input,
246 const SignCallback& callback) {
247 if (!service) {
248 const std::vector<uint8_t> no_signature;
249 callback.Run(net::ERR_FAILED, no_signature);
250 return;
252 service->RequestSignatureFromExtension(extension_id, certificate, hash, input,
253 callback);
256 void CertificateProviderService::SSLPrivateKey::SignDigest(
257 Hash hash,
258 const base::StringPiece& input,
259 const SignCallback& callback) {
260 DCHECK(thread_checker_.CalledOnValidThread());
261 const scoped_refptr<base::TaskRunner> source_task_runner =
262 base::ThreadTaskRunnerHandle::Get();
263 const SignCallback bound_callback =
264 // The CertificateProviderService calls back on another thread, so post
265 // back to the current thread.
266 base::Bind(&PostSignResultToTaskRunner, source_task_runner,
267 // Drop the result and don't call back if this key handle is
268 // destroyed in the meantime.
269 base::Bind(&SSLPrivateKey::DidSignDigest,
270 weak_factory_.GetWeakPtr(), callback));
272 service_task_runner_->PostTask(
273 FROM_HERE, base::Bind(&SSLPrivateKey::SignDigestOnServiceTaskRunner,
274 service_, extension_id_, cert_info_.certificate,
275 hash, input.as_string(), bound_callback));
278 void CertificateProviderService::SSLPrivateKey::DidSignDigest(
279 const SignCallback& callback,
280 net::Error error,
281 const std::vector<uint8_t>& signature) {
282 DCHECK(thread_checker_.CalledOnValidThread());
283 callback.Run(error, signature);
286 CertificateProviderService::CertificateProviderService()
287 : weak_factory_(this) {}
289 CertificateProviderService::~CertificateProviderService() {
290 DCHECK(thread_checker_.CalledOnValidThread());
292 // ClientKeyStore serializes access to |cert_key_provider_|.
293 // Once RemoveProvider() returns, it is guaranteed that there are no more
294 // accesses to |cert_key_provider_| in flight and no references to
295 // |cert_key_provider_| are remaining. This service will hold the last
296 // reference to |cert_key_provider_|.
297 net::ClientKeyStore::GetInstance()->RemoveProvider(cert_key_provider_.get());
298 cert_key_provider_.reset();
301 void CertificateProviderService::SetDelegate(scoped_ptr<Delegate> delegate) {
302 DCHECK(thread_checker_.CalledOnValidThread());
303 DCHECK(!delegate_);
304 DCHECK(delegate);
306 delegate_ = delegate.Pass();
307 cert_key_provider_.reset(
308 new CertKeyProviderImpl(base::ThreadTaskRunnerHandle::Get(),
309 weak_factory_.GetWeakPtr(), &certificate_map_));
310 net::ClientKeyStore::GetInstance()->AddProvider(cert_key_provider_.get());
313 bool CertificateProviderService::SetCertificatesProvidedByExtension(
314 const std::string& extension_id,
315 int cert_request_id,
316 const CertificateInfoList& certificate_infos) {
317 DCHECK(thread_checker_.CalledOnValidThread());
319 bool completed = false;
320 if (!certificate_requests_.SetCertificates(extension_id, cert_request_id,
321 certificate_infos, &completed)) {
322 DLOG(WARNING) << "Unexpected reply of extension " << extension_id
323 << " to request " << cert_request_id;
324 return false;
326 if (completed) {
327 std::map<std::string, CertificateInfoList> certificates;
328 base::Callback<void(const net::CertificateList&)> callback;
329 certificate_requests_.RemoveRequest(cert_request_id, &certificates,
330 &callback);
331 UpdateCertificatesAndRun(certificates, callback);
333 return true;
336 void CertificateProviderService::ReplyToSignRequest(
337 const std::string& extension_id,
338 int sign_request_id,
339 const std::vector<uint8_t>& signature) {
340 DCHECK(thread_checker_.CalledOnValidThread());
342 net::SSLPrivateKey::SignCallback callback;
343 if (!sign_requests_.RemoveRequest(extension_id, sign_request_id, &callback)) {
344 LOG(ERROR) << "request id unknown.";
345 // Maybe multiple replies to the same request.
346 return;
349 const net::Error error_code = signature.empty() ? net::ERR_FAILED : net::OK;
350 callback.Run(error_code, signature);
353 scoped_ptr<CertificateProvider>
354 CertificateProviderService::CreateCertificateProvider() {
355 DCHECK(thread_checker_.CalledOnValidThread());
357 return make_scoped_ptr(new CertificateProviderImpl(
358 base::ThreadTaskRunnerHandle::Get(), weak_factory_.GetWeakPtr()));
361 void CertificateProviderService::OnExtensionUnloaded(
362 const std::string& extension_id) {
363 DCHECK(thread_checker_.CalledOnValidThread());
365 for (const int cert_request_id :
366 certificate_requests_.DropExtension(extension_id)) {
367 std::map<std::string, CertificateInfoList> certificates;
368 base::Callback<void(const net::CertificateList&)> callback;
369 certificate_requests_.RemoveRequest(cert_request_id, &certificates,
370 &callback);
371 UpdateCertificatesAndRun(certificates, callback);
374 certificate_map_.RemoveExtension(extension_id);
376 for (auto callback : sign_requests_.RemoveAllRequests(extension_id))
377 callback.Run(net::ERR_FAILED, std::vector<uint8_t>());
380 void CertificateProviderService::GetCertificatesFromExtensions(
381 const base::Callback<void(const net::CertificateList&)>& callback) {
382 DCHECK(thread_checker_.CalledOnValidThread());
384 const std::vector<std::string> provider_extensions(
385 delegate_->CertificateProviderExtensions());
387 if (provider_extensions.empty()) {
388 DVLOG(2) << "No provider extensions left, clear all certificates.";
389 UpdateCertificatesAndRun(std::map<std::string, CertificateInfoList>(),
390 callback);
391 return;
394 const int cert_request_id = certificate_requests_.AddRequest(
395 provider_extensions, callback,
396 base::Bind(&CertificateProviderService::TerminateCertificateRequest,
397 base::Unretained(this)));
399 DVLOG(2) << "Start certificate request " << cert_request_id;
400 delegate_->BroadcastCertificateRequest(cert_request_id);
403 void CertificateProviderService::UpdateCertificatesAndRun(
404 const std::map<std::string, CertificateInfoList>& extension_to_certificates,
405 const base::Callback<void(const net::CertificateList&)>& callback) {
406 DCHECK(thread_checker_.CalledOnValidThread());
408 // Extensions are removed from the service's state when they're unloaded.
409 // Any remaining extension is assumed to be enabled.
410 certificate_map_.Update(extension_to_certificates);
412 net::CertificateList all_certs;
413 for (const auto& entry : extension_to_certificates) {
414 for (const CertificateInfo& cert_info : entry.second)
415 all_certs.push_back(cert_info.certificate);
418 callback.Run(all_certs);
421 void CertificateProviderService::TerminateCertificateRequest(
422 int cert_request_id) {
423 DCHECK(thread_checker_.CalledOnValidThread());
425 std::map<std::string, CertificateInfoList> certificates;
426 base::Callback<void(const net::CertificateList&)> callback;
427 if (!certificate_requests_.RemoveRequest(cert_request_id, &certificates,
428 &callback)) {
429 DLOG(WARNING) << "Request id " << cert_request_id << " unknown.";
430 return;
433 DVLOG(1) << "Time out certificate request " << cert_request_id;
434 UpdateCertificatesAndRun(certificates, callback);
437 void CertificateProviderService::RequestSignatureFromExtension(
438 const std::string& extension_id,
439 const scoped_refptr<net::X509Certificate>& certificate,
440 net::SSLPrivateKey::Hash hash,
441 const std::string& digest,
442 const net::SSLPrivateKey::SignCallback& callback) {
443 DCHECK(thread_checker_.CalledOnValidThread());
445 const int sign_request_id = sign_requests_.AddRequest(extension_id, callback);
446 if (!delegate_->DispatchSignRequestToExtension(extension_id, sign_request_id,
447 hash, certificate, digest)) {
448 sign_requests_.RemoveRequest(extension_id, sign_request_id,
449 nullptr /* callback */);
450 callback.Run(net::ERR_FAILED, std::vector<uint8_t>());
454 } // namespace chromeos