Merge Chromium + Blink git repositories
[chromium-blink-merge.git] / chrome / browser / chromeos / certificate_provider / certificate_provider_service.cc
blob ef0081b23fcaf8bc8614c1a73b441a2b0b9705aa
1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/chromeos/certificate_provider/certificate_provider_service.h"
7 #include "base/bind.h"
8 #include "base/bind_helpers.h"
9 #include "base/callback.h"
10 #include "base/location.h"
11 #include "base/logging.h"
12 #include "base/stl_util.h"
13 #include "base/strings/string_piece.h"
14 #include "base/task_runner.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "chrome/browser/chromeos/certificate_provider/certificate_provider.h"
17 #include "net/base/net_errors.h"
18 #include "net/ssl/client_key_store.h"
20 namespace chromeos {
22 namespace {
24 void PostSignResultToTaskRunner(
25 const scoped_refptr<base::TaskRunner>& target_task_runner,
26 const net::SSLPrivateKey::SignCallback& callback,
27 net::Error error,
28 const std::vector<uint8_t>& signature) {
29 target_task_runner->PostTask(FROM_HERE,
30 base::Bind(callback, error, signature));
33 void PostCertificatesToTaskRunner(
34 const scoped_refptr<base::TaskRunner>& target_task_runner,
35 const base::Callback<void(const net::CertificateList&)>& callback,
36 const net::CertificateList& certs) {
37 target_task_runner->PostTask(FROM_HERE, base::Bind(callback, certs));
40 } // namespace
42 class CertificateProviderService::CertKeyProviderImpl
43 : public net::ClientKeyStore::CertKeyProvider {
44 public:
45 // |certificate_map| must outlive this provider. |service| must be
46 // dereferenceable on |service_task_runner|.
47 // This provider may be accessed from any thread. Methods and destructor must
48 // never be called concurrently.
49 CertKeyProviderImpl(
50 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
51 const base::WeakPtr<CertificateProviderService>& service,
52 certificate_provider::ThreadSafeCertificateMap* certificate_map);
53 ~CertKeyProviderImpl() override;
55 bool GetCertificateKey(const net::X509Certificate& cert,
56 scoped_ptr<net::SSLPrivateKey>* private_key) override;
58 private:
59 const scoped_refptr<base::SequencedTaskRunner> service_task_runner_;
60 // Must be dereferenced on |service_task_runner_| only.
61 base::WeakPtr<CertificateProviderService> service_;
62 certificate_provider::ThreadSafeCertificateMap* const certificate_map_;
64 DISALLOW_COPY_AND_ASSIGN(CertKeyProviderImpl);
67 class CertificateProviderService::CertificateProviderImpl
68 : public CertificateProvider {
69 public:
70 // Any calls back to |service| will be posted to |service_task_runner|.
71 // |service| must be dereferenceable on |service_task_runner|.
72 // This provider is not thread safe, but can be used on any thread.
73 CertificateProviderImpl(
74 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
75 const base::WeakPtr<CertificateProviderService>& service);
76 ~CertificateProviderImpl() override;
78 void GetCertificates(const base::Callback<void(const net::CertificateList&)>&
79 callback) override;
81 scoped_ptr<CertificateProvider> Copy() override;
83 private:
84 static void GetCertificatesOnServiceThread(
85 const base::WeakPtr<CertificateProviderService>& service,
86 const base::Callback<void(const net::CertificateList&)>& callback);
88 const scoped_refptr<base::SequencedTaskRunner> service_task_runner_;
89 // Must be dereferenced on |service_task_runner_| only.
90 const base::WeakPtr<CertificateProviderService> service_;
92 DISALLOW_COPY_AND_ASSIGN(CertificateProviderImpl);
95 // Implements an SSLPrivateKey backed by the signing function exposed by an
96 // extension through the certificateProvider API.
97 // Objects of this class must be used on a single thread. Any thread is allowed.
98 class CertificateProviderService::SSLPrivateKey : public net::SSLPrivateKey {
99 public:
100 // Any calls back to |service| will be posted to |service_task_runner|.
101 // |service| must be dereferenceable on |service_task_runner|.
102 SSLPrivateKey(
103 const std::string& extension_id,
104 const CertificateInfo& cert_info,
105 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
106 const base::WeakPtr<CertificateProviderService>& service);
107 ~SSLPrivateKey() override;
109 // net::SSLPrivateKey:
110 Type GetType() override;
111 bool SupportsHash(Hash hash) override;
112 size_t GetMaxSignatureLengthInBytes() override;
113 void SignDigest(Hash hash,
114 const base::StringPiece& input,
115 const SignCallback& callback) override;
117 private:
118 static void SignDigestOnServiceTaskRunner(
119 const base::WeakPtr<CertificateProviderService>& service,
120 const std::string& extension_id,
121 const scoped_refptr<net::X509Certificate>& certificate,
122 Hash hash,
123 const std::string& input,
124 const SignCallback& callback);
126 void DidSignDigest(const SignCallback& callback,
127 net::Error error,
128 const std::vector<uint8_t>& signature);
130 const std::string extension_id_;
131 const CertificateInfo cert_info_;
132 scoped_refptr<base::SequencedTaskRunner> service_task_runner_;
133 // Must be dereferenced on |service_task_runner_| only.
134 const base::WeakPtr<CertificateProviderService> service_;
135 base::ThreadChecker thread_checker_;
136 base::WeakPtrFactory<SSLPrivateKey> weak_factory_;
138 DISALLOW_COPY_AND_ASSIGN(SSLPrivateKey);
141 CertificateProviderService::CertKeyProviderImpl::CertKeyProviderImpl(
142 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
143 const base::WeakPtr<CertificateProviderService>& service,
144 certificate_provider::ThreadSafeCertificateMap* certificate_map)
145 : service_task_runner_(service_task_runner),
146 service_(service),
147 certificate_map_(certificate_map) {}
149 CertificateProviderService::CertKeyProviderImpl::~CertKeyProviderImpl() {}
151 bool CertificateProviderService::CertKeyProviderImpl::GetCertificateKey(
152 const net::X509Certificate& cert,
153 scoped_ptr<net::SSLPrivateKey>* private_key) {
154 bool is_currently_provided = false;
155 CertificateInfo info;
156 std::string extension_id;
157 certificate_map_->LookUpCertificate(cert, &is_currently_provided, &info,
158 &extension_id);
159 if (!is_currently_provided)
160 return false;
162 private_key->reset(
163 new SSLPrivateKey(extension_id, info, service_task_runner_, service_));
164 return true;
167 CertificateProviderService::CertificateProviderImpl::CertificateProviderImpl(
168 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
169 const base::WeakPtr<CertificateProviderService>& service)
170 : service_task_runner_(service_task_runner), service_(service) {}
172 CertificateProviderService::CertificateProviderImpl::
173 ~CertificateProviderImpl() {}
175 void CertificateProviderService::CertificateProviderImpl::GetCertificates(
176 const base::Callback<void(const net::CertificateList&)>& callback) {
177 const scoped_refptr<base::TaskRunner> source_task_runner =
178 base::ThreadTaskRunnerHandle::Get();
179 const base::Callback<void(const net::CertificateList&)>
180 callback_from_service_thread = base::Bind(&PostCertificatesToTaskRunner,
181 source_task_runner, callback);
183 service_task_runner_->PostTask(
184 FROM_HERE, base::Bind(&GetCertificatesOnServiceThread, service_,
185 callback_from_service_thread));
188 scoped_ptr<CertificateProvider>
189 CertificateProviderService::CertificateProviderImpl::Copy() {
190 return make_scoped_ptr(
191 new CertificateProviderImpl(service_task_runner_, service_));
194 // static
195 void CertificateProviderService::CertificateProviderImpl::
196 GetCertificatesOnServiceThread(
197 const base::WeakPtr<CertificateProviderService>& service,
198 const base::Callback<void(const net::CertificateList&)>& callback) {
199 if (!service) {
200 callback.Run(net::CertificateList());
201 return;
203 service->GetCertificatesFromExtensions(callback);
206 CertificateProviderService::SSLPrivateKey::SSLPrivateKey(
207 const std::string& extension_id,
208 const CertificateInfo& cert_info,
209 const scoped_refptr<base::SequencedTaskRunner>& service_task_runner,
210 const base::WeakPtr<CertificateProviderService>& service)
211 : extension_id_(extension_id),
212 cert_info_(cert_info),
213 service_task_runner_(service_task_runner),
214 service_(service),
215 weak_factory_(this) {
216 // This constructor is called on |service_task_runner|. Only subsequent calls
217 // to member functions have to be on a common thread.
218 thread_checker_.DetachFromThread();
221 CertificateProviderService::SSLPrivateKey::~SSLPrivateKey() {
222 DCHECK(thread_checker_.CalledOnValidThread());
225 CertificateProviderService::SSLPrivateKey::Type
226 CertificateProviderService::SSLPrivateKey::GetType() {
227 DCHECK(thread_checker_.CalledOnValidThread());
228 return cert_info_.type;
231 bool CertificateProviderService::SSLPrivateKey::SupportsHash(Hash hash) {
232 DCHECK(thread_checker_.CalledOnValidThread());
233 return ContainsValue(cert_info_.supported_hashes, hash);
236 size_t
237 CertificateProviderService::SSLPrivateKey::GetMaxSignatureLengthInBytes() {
238 DCHECK(thread_checker_.CalledOnValidThread());
239 return cert_info_.max_signature_length_in_bytes;
242 // static
243 void CertificateProviderService::SSLPrivateKey::SignDigestOnServiceTaskRunner(
244 const base::WeakPtr<CertificateProviderService>& service,
245 const std::string& extension_id,
246 const scoped_refptr<net::X509Certificate>& certificate,
247 Hash hash,
248 const std::string& input,
249 const SignCallback& callback) {
250 if (!service) {
251 const std::vector<uint8_t> no_signature;
252 callback.Run(net::ERR_FAILED, no_signature);
253 return;
255 service->RequestSignatureFromExtension(extension_id, certificate, hash, input,
256 callback);
259 void CertificateProviderService::SSLPrivateKey::SignDigest(
260 Hash hash,
261 const base::StringPiece& input,
262 const SignCallback& callback) {
263 DCHECK(thread_checker_.CalledOnValidThread());
264 const scoped_refptr<base::TaskRunner> source_task_runner =
265 base::ThreadTaskRunnerHandle::Get();
266 const SignCallback bound_callback =
267 // The CertificateProviderService calls back on another thread, so post
268 // back to the current thread.
269 base::Bind(&PostSignResultToTaskRunner, source_task_runner,
270 // Drop the result and don't call back if this key handle is
271 // destroyed in the meantime.
272 base::Bind(&SSLPrivateKey::DidSignDigest,
273 weak_factory_.GetWeakPtr(), callback));
275 service_task_runner_->PostTask(
276 FROM_HERE, base::Bind(&SSLPrivateKey::SignDigestOnServiceTaskRunner,
277 service_, extension_id_, cert_info_.certificate,
278 hash, input.as_string(), bound_callback));
281 void CertificateProviderService::SSLPrivateKey::DidSignDigest(
282 const SignCallback& callback,
283 net::Error error,
284 const std::vector<uint8_t>& signature) {
285 DCHECK(thread_checker_.CalledOnValidThread());
286 callback.Run(error, signature);
289 CertificateProviderService::CertificateProviderService()
290 : weak_factory_(this) {}
292 CertificateProviderService::~CertificateProviderService() {
293 DCHECK(thread_checker_.CalledOnValidThread());
295 // ClientKeyStore serializes access to |cert_key_provider_|.
296 // Once RemoveProvider() returns, it is guaranteed that there are no more
297 // accesses to |cert_key_provider_| in flight and no references to
298 // |cert_key_provider_| are remaining. This service will hold the last
299 // reference to |cert_key_provider_|.
300 net::ClientKeyStore::GetInstance()->RemoveProvider(cert_key_provider_.get());
301 cert_key_provider_.reset();
304 void CertificateProviderService::SetDelegate(scoped_ptr<Delegate> delegate) {
305 DCHECK(thread_checker_.CalledOnValidThread());
306 DCHECK(!delegate_);
307 DCHECK(delegate);
309 delegate_ = delegate.Pass();
310 cert_key_provider_.reset(
311 new CertKeyProviderImpl(base::ThreadTaskRunnerHandle::Get(),
312 weak_factory_.GetWeakPtr(), &certificate_map_));
313 net::ClientKeyStore::GetInstance()->AddProvider(cert_key_provider_.get());
316 bool CertificateProviderService::SetCertificatesProvidedByExtension(
317 const std::string& extension_id,
318 int cert_request_id,
319 const CertificateInfoList& certificate_infos) {
320 DCHECK(thread_checker_.CalledOnValidThread());
322 bool completed = false;
323 if (!certificate_requests_.SetCertificates(extension_id, cert_request_id,
324 certificate_infos, &completed)) {
325 DLOG(WARNING) << "Unexpected reply of extension " << extension_id
326 << " to request " << cert_request_id;
327 return false;
329 if (completed) {
330 std::map<std::string, CertificateInfoList> certificates;
331 base::Callback<void(const net::CertificateList&)> callback;
332 certificate_requests_.RemoveRequest(cert_request_id, &certificates,
333 &callback);
334 UpdateCertificatesAndRun(certificates, callback);
336 return true;
339 void CertificateProviderService::ReplyToSignRequest(
340 const std::string& extension_id,
341 int sign_request_id,
342 const std::vector<uint8_t>& signature) {
343 DCHECK(thread_checker_.CalledOnValidThread());
345 net::SSLPrivateKey::SignCallback callback;
346 if (!sign_requests_.RemoveRequest(extension_id, sign_request_id, &callback)) {
347 LOG(ERROR) << "request id unknown.";
348 // Maybe multiple replies to the same request.
349 return;
352 const net::Error error_code = signature.empty() ? net::ERR_FAILED : net::OK;
353 callback.Run(error_code, signature);
356 bool CertificateProviderService::LookUpCertificate(
357 const net::X509Certificate& cert,
358 bool* has_extension,
359 std::string* extension_id) {
360 DCHECK(thread_checker_.CalledOnValidThread());
362 CertificateInfo unused_info;
363 return certificate_map_.LookUpCertificate(cert, has_extension, &unused_info,
364 extension_id);
367 scoped_ptr<CertificateProvider>
368 CertificateProviderService::CreateCertificateProvider() {
369 DCHECK(thread_checker_.CalledOnValidThread());
371 return make_scoped_ptr(new CertificateProviderImpl(
372 base::ThreadTaskRunnerHandle::Get(), weak_factory_.GetWeakPtr()));
375 void CertificateProviderService::OnExtensionUnloaded(
376 const std::string& extension_id) {
377 DCHECK(thread_checker_.CalledOnValidThread());
379 for (const int cert_request_id :
380 certificate_requests_.DropExtension(extension_id)) {
381 std::map<std::string, CertificateInfoList> certificates;
382 base::Callback<void(const net::CertificateList&)> callback;
383 certificate_requests_.RemoveRequest(cert_request_id, &certificates,
384 &callback);
385 UpdateCertificatesAndRun(certificates, callback);
388 certificate_map_.RemoveExtension(extension_id);
390 for (auto callback : sign_requests_.RemoveAllRequests(extension_id))
391 callback.Run(net::ERR_FAILED, std::vector<uint8_t>());
394 void CertificateProviderService::GetCertificatesFromExtensions(
395 const base::Callback<void(const net::CertificateList&)>& callback) {
396 DCHECK(thread_checker_.CalledOnValidThread());
398 const std::vector<std::string> provider_extensions(
399 delegate_->CertificateProviderExtensions());
401 if (provider_extensions.empty()) {
402 DVLOG(2) << "No provider extensions left, clear all certificates.";
403 UpdateCertificatesAndRun(std::map<std::string, CertificateInfoList>(),
404 callback);
405 return;
408 const int cert_request_id = certificate_requests_.AddRequest(
409 provider_extensions, callback,
410 base::Bind(&CertificateProviderService::TerminateCertificateRequest,
411 base::Unretained(this)));
413 DVLOG(2) << "Start certificate request " << cert_request_id;
414 delegate_->BroadcastCertificateRequest(cert_request_id);
417 void CertificateProviderService::UpdateCertificatesAndRun(
418 const std::map<std::string, CertificateInfoList>& extension_to_certificates,
419 const base::Callback<void(const net::CertificateList&)>& callback) {
420 DCHECK(thread_checker_.CalledOnValidThread());
422 // Extensions are removed from the service's state when they're unloaded.
423 // Any remaining extension is assumed to be enabled.
424 certificate_map_.Update(extension_to_certificates);
426 net::CertificateList all_certs;
427 for (const auto& entry : extension_to_certificates) {
428 for (const CertificateInfo& cert_info : entry.second)
429 all_certs.push_back(cert_info.certificate);
432 callback.Run(all_certs);
435 void CertificateProviderService::TerminateCertificateRequest(
436 int cert_request_id) {
437 DCHECK(thread_checker_.CalledOnValidThread());
439 std::map<std::string, CertificateInfoList> certificates;
440 base::Callback<void(const net::CertificateList&)> callback;
441 if (!certificate_requests_.RemoveRequest(cert_request_id, &certificates,
442 &callback)) {
443 DLOG(WARNING) << "Request id " << cert_request_id << " unknown.";
444 return;
447 DVLOG(1) << "Time out certificate request " << cert_request_id;
448 UpdateCertificatesAndRun(certificates, callback);
451 void CertificateProviderService::RequestSignatureFromExtension(
452 const std::string& extension_id,
453 const scoped_refptr<net::X509Certificate>& certificate,
454 net::SSLPrivateKey::Hash hash,
455 const std::string& digest,
456 const net::SSLPrivateKey::SignCallback& callback) {
457 DCHECK(thread_checker_.CalledOnValidThread());
459 const int sign_request_id = sign_requests_.AddRequest(extension_id, callback);
460 if (!delegate_->DispatchSignRequestToExtension(extension_id, sign_request_id,
461 hash, certificate, digest)) {
462 sign_requests_.RemoveRequest(extension_id, sign_request_id,
463 nullptr /* callback */);
464 callback.Run(net::ERR_FAILED, std::vector<uint8_t>());
468 } // namespace chromeos