1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "media/cdm/json_web_key.h"
7 #include "base/base64.h"
8 #include "base/json/json_reader.h"
9 #include "base/json/json_string_value_serializer.h"
10 #include "base/json/string_escape.h"
11 #include "base/logging.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/strings/string_number_conversions.h"
14 #include "base/strings/string_util.h"
15 #include "base/values.h"
// Member names used in JSON Web Keys (JWK) and JWK Sets.
const char kKeysTag[] = "keys";
const char kKeyTypeTag[] = "kty";
const char kKeyTypeOct[] = "oct";  // Octet sequence.
const char kKeyTag[] = "k";
const char kKeyIdTag[] = "kid";
const char kKeyIdsTag[] = "kids";

// Characters that differ between standard base64 and base64url. base64url
// uses '-' and '_' in place of '+' and '/'; kBase64UrlInvalid lists the
// characters that are rejected when they appear in base64url input.
const char kBase64Padding = '=';
const char kBase64Plus[] = "+";
const char kBase64UrlPlusReplacement[] = "-";
const char kBase64Slash[] = "/";
const char kBase64UrlSlashReplacement[] = "_";
const char kBase64UrlInvalid[] = "+/=";

// Member name and values describing the session type.
const char kTypeTag[] = "type";
const char kTemporarySession[] = "temporary";
const char kPersistentLicenseSession[] = "persistent-license";
const char kPersistentReleaseMessageSession[] = "persistent-release-message";
36 // Encodes |input| into a base64url string without padding.
37 static std::string
EncodeBase64Url(const uint8
* input
, int input_length
) {
38 std::string encoded_text
;
40 std::string(reinterpret_cast<const char*>(input
), input_length
),
43 // Remove any padding characters added by Base64Encode().
44 size_t found
= encoded_text
.find_last_not_of(kBase64Padding
);
45 if (found
!= std::string::npos
)
46 encoded_text
.erase(found
+ 1);
48 // base64url encoding means the characters '-' and '_' must be used
49 // instead of '+' and '/', respectively.
50 base::ReplaceChars(encoded_text
, kBase64Plus
, kBase64UrlPlusReplacement
,
52 base::ReplaceChars(encoded_text
, kBase64Slash
, kBase64UrlSlashReplacement
,
58 // Decodes a base64url string. Returns empty string on error.
59 static std::string
DecodeBase64Url(const std::string
& encoded_text
) {
60 // EME spec doesn't allow '+', '/', or padding characters.
61 if (encoded_text
.find_first_of(kBase64UrlInvalid
) != std::string::npos
) {
62 DVLOG(1) << "Invalid base64url format: " << encoded_text
;
66 // Since base::Base64Decode() requires padding characters, add them so length
67 // of |encoded_text| is exactly a multiple of 4.
68 size_t num_last_grouping_chars
= encoded_text
.length() % 4;
69 std::string modified_text
= encoded_text
;
70 if (num_last_grouping_chars
> 0)
71 modified_text
.append(4 - num_last_grouping_chars
, kBase64Padding
);
73 // base64url encoding means the characters '-' and '_' must be used
74 // instead of '+' and '/', respectively, so replace them before calling
75 // base::Base64Decode().
76 base::ReplaceChars(modified_text
, kBase64UrlPlusReplacement
, kBase64Plus
,
78 base::ReplaceChars(modified_text
, kBase64UrlSlashReplacement
, kBase64Slash
,
81 std::string decoded_text
;
82 if (!base::Base64Decode(modified_text
, &decoded_text
)) {
83 DVLOG(1) << "Base64 decoding failed on: " << modified_text
;
90 static std::string
ShortenTo64Characters(const std::string
& input
) {
91 // Convert |input| into a string with escaped characters replacing any
92 // non-ASCII characters. Limiting |input| to the first 65 characters so
93 // we don't waste time converting a potentially long string and then
94 // throwing away the excess.
95 std::string escaped_str
=
96 base::EscapeBytesAsInvalidJSONString(input
.substr(0, 65), false);
97 if (escaped_str
.length() <= 64u)
100 // This may end up truncating an escaped character, but the first part of
101 // the string should provide enough information.
102 return escaped_str
.substr(0, 61).append("...");
105 static scoped_ptr
<base::DictionaryValue
> CreateJSONDictionary(
110 scoped_ptr
<base::DictionaryValue
> jwk(new base::DictionaryValue());
111 jwk
->SetString(kKeyTypeTag
, kKeyTypeOct
);
112 jwk
->SetString(kKeyTag
, EncodeBase64Url(key
, key_length
));
113 jwk
->SetString(kKeyIdTag
, EncodeBase64Url(key_id
, key_id_length
));
117 std::string
GenerateJWKSet(const uint8
* key
, int key_length
,
118 const uint8
* key_id
, int key_id_length
) {
119 // Create the JWK, and wrap it into a JWK Set.
120 scoped_ptr
<base::ListValue
> list(new base::ListValue());
122 CreateJSONDictionary(key
, key_length
, key_id
, key_id_length
).release());
123 base::DictionaryValue jwk_set
;
124 jwk_set
.Set(kKeysTag
, list
.release());
126 // Finally serialize |jwk_set| into a string and return it.
127 std::string serialized_jwk
;
128 JSONStringValueSerializer
serializer(&serialized_jwk
);
129 serializer
.Serialize(jwk_set
);
130 return serialized_jwk
;
133 std::string
GenerateJWKSet(const KeyIdAndKeyPairs
& keys
,
134 MediaKeys::SessionType session_type
) {
135 scoped_ptr
<base::ListValue
> list(new base::ListValue());
136 for (const auto& key_pair
: keys
) {
137 list
->Append(CreateJSONDictionary(
138 reinterpret_cast<const uint8
*>(key_pair
.second
.data()),
139 key_pair
.second
.length(),
140 reinterpret_cast<const uint8
*>(key_pair
.first
.data()),
141 key_pair
.first
.length())
145 base::DictionaryValue jwk_set
;
146 jwk_set
.Set(kKeysTag
, list
.release());
147 switch (session_type
) {
148 case MediaKeys::TEMPORARY_SESSION
:
149 jwk_set
.SetString(kTypeTag
, kTemporarySession
);
151 case MediaKeys::PERSISTENT_LICENSE_SESSION
:
152 jwk_set
.SetString(kTypeTag
, kPersistentLicenseSession
);
154 case MediaKeys::PERSISTENT_RELEASE_MESSAGE_SESSION
:
155 jwk_set
.SetString(kTypeTag
, kPersistentReleaseMessageSession
);
159 // Finally serialize |jwk_set| into a string and return it.
160 std::string serialized_jwk
;
161 JSONStringValueSerializer
serializer(&serialized_jwk
);
162 serializer
.Serialize(jwk_set
);
163 return serialized_jwk
;
166 // Processes a JSON Web Key to extract the key id and key value. Sets |jwk_key|
167 // to the id/value pair and returns true on success.
168 static bool ConvertJwkToKeyPair(const base::DictionaryValue
& jwk
,
169 KeyIdAndKeyPair
* jwk_key
) {
171 if (!jwk
.GetString(kKeyTypeTag
, &type
) || type
!= kKeyTypeOct
) {
172 DVLOG(1) << "Missing or invalid '" << kKeyTypeTag
<< "': " << type
;
176 // Get the key id and actual key parameters.
177 std::string encoded_key_id
;
178 std::string encoded_key
;
179 if (!jwk
.GetString(kKeyIdTag
, &encoded_key_id
)) {
180 DVLOG(1) << "Missing '" << kKeyIdTag
<< "' parameter";
183 if (!jwk
.GetString(kKeyTag
, &encoded_key
)) {
184 DVLOG(1) << "Missing '" << kKeyTag
<< "' parameter";
188 // Key ID and key are base64-encoded strings, so decode them.
189 std::string raw_key_id
= DecodeBase64Url(encoded_key_id
);
190 if (raw_key_id
.empty()) {
191 DVLOG(1) << "Invalid '" << kKeyIdTag
<< "' value: " << encoded_key_id
;
195 std::string raw_key
= DecodeBase64Url(encoded_key
);
196 if (raw_key
.empty()) {
197 DVLOG(1) << "Invalid '" << kKeyTag
<< "' value: " << encoded_key
;
201 // Add the decoded key ID and the decoded key to the list.
202 *jwk_key
= std::make_pair(raw_key_id
, raw_key
);
206 bool ExtractKeysFromJWKSet(const std::string
& jwk_set
,
207 KeyIdAndKeyPairs
* keys
,
208 MediaKeys::SessionType
* session_type
) {
209 if (!base::IsStringASCII(jwk_set
)) {
210 DVLOG(1) << "Non ASCII JWK Set: " << jwk_set
;
214 scoped_ptr
<base::Value
> root(base::JSONReader().ReadToValue(jwk_set
));
215 if (!root
.get() || root
->GetType() != base::Value::TYPE_DICTIONARY
) {
216 DVLOG(1) << "Not valid JSON: " << jwk_set
<< ", root: " << root
.get();
220 // Locate the set from the dictionary.
221 base::DictionaryValue
* dictionary
=
222 static_cast<base::DictionaryValue
*>(root
.get());
223 base::ListValue
* list_val
= NULL
;
224 if (!dictionary
->GetList(kKeysTag
, &list_val
)) {
225 DVLOG(1) << "Missing '" << kKeysTag
226 << "' parameter or not a list in JWK Set";
230 // Create a local list of keys, so that |jwk_keys| only gets updated on
232 KeyIdAndKeyPairs local_keys
;
233 for (size_t i
= 0; i
< list_val
->GetSize(); ++i
) {
234 base::DictionaryValue
* jwk
= NULL
;
235 if (!list_val
->GetDictionary(i
, &jwk
)) {
236 DVLOG(1) << "Unable to access '" << kKeysTag
<< "'[" << i
240 KeyIdAndKeyPair key_pair
;
241 if (!ConvertJwkToKeyPair(*jwk
, &key_pair
)) {
242 DVLOG(1) << "Error from '" << kKeysTag
<< "'[" << i
<< "]";
245 local_keys
.push_back(key_pair
);
248 // Successfully processed all JWKs in the set. Now check if "type" is
250 base::Value
* value
= NULL
;
251 std::string session_type_id
;
252 if (!dictionary
->Get(kTypeTag
, &value
)) {
253 // Not specified, so use the default type.
254 *session_type
= MediaKeys::TEMPORARY_SESSION
;
255 } else if (!value
->GetAsString(&session_type_id
)) {
256 DVLOG(1) << "Invalid '" << kTypeTag
<< "' value";
258 } else if (session_type_id
== kTemporarySession
) {
259 *session_type
= MediaKeys::TEMPORARY_SESSION
;
260 } else if (session_type_id
== kPersistentLicenseSession
) {
261 *session_type
= MediaKeys::PERSISTENT_LICENSE_SESSION
;
262 } else if (session_type_id
== kPersistentReleaseMessageSession
) {
263 *session_type
= MediaKeys::PERSISTENT_RELEASE_MESSAGE_SESSION
;
265 DVLOG(1) << "Invalid '" << kTypeTag
<< "' value: " << session_type_id
;
270 keys
->swap(local_keys
);
274 bool ExtractKeyIdsFromKeyIdsInitData(const std::string
& input
,
276 std::string
* error_message
) {
277 if (!base::IsStringASCII(input
)) {
278 error_message
->assign("Non ASCII: ");
279 error_message
->append(ShortenTo64Characters(input
));
283 scoped_ptr
<base::Value
> root(base::JSONReader().ReadToValue(input
));
284 if (!root
.get() || root
->GetType() != base::Value::TYPE_DICTIONARY
) {
285 error_message
->assign("Not valid JSON: ");
286 error_message
->append(ShortenTo64Characters(input
));
290 // Locate the set from the dictionary.
291 base::DictionaryValue
* dictionary
=
292 static_cast<base::DictionaryValue
*>(root
.get());
293 base::ListValue
* list_val
= NULL
;
294 if (!dictionary
->GetList(kKeyIdsTag
, &list_val
)) {
295 error_message
->assign("Missing '");
296 error_message
->append(kKeyIdsTag
);
297 error_message
->append("' parameter or not a list");
301 // Create a local list of key ids, so that |key_ids| only gets updated on
303 KeyIdList local_key_ids
;
304 for (size_t i
= 0; i
< list_val
->GetSize(); ++i
) {
305 std::string encoded_key_id
;
306 if (!list_val
->GetString(i
, &encoded_key_id
)) {
307 error_message
->assign("'");
308 error_message
->append(kKeyIdsTag
);
309 error_message
->append("'[");
310 error_message
->append(base::UintToString(i
));
311 error_message
->append("] is not string.");
315 // Key ID is a base64-encoded string, so decode it.
316 std::string raw_key_id
= DecodeBase64Url(encoded_key_id
);
317 if (raw_key_id
.empty()) {
318 error_message
->assign("'");
319 error_message
->append(kKeyIdsTag
);
320 error_message
->append("'[");
321 error_message
->append(base::UintToString(i
));
322 error_message
->append("] is not valid base64url encoded. Value: ");
323 error_message
->append(ShortenTo64Characters(encoded_key_id
));
327 // Add the decoded key ID to the list.
328 local_key_ids
.push_back(std::vector
<uint8
>(
329 raw_key_id
.data(), raw_key_id
.data() + raw_key_id
.length()));
333 key_ids
->swap(local_key_ids
);
334 error_message
->clear();
338 void CreateLicenseRequest(const KeyIdList
& key_ids
,
339 MediaKeys::SessionType session_type
,
340 std::vector
<uint8
>* license
) {
341 // Create the license request.
342 scoped_ptr
<base::DictionaryValue
> request(new base::DictionaryValue());
343 scoped_ptr
<base::ListValue
> list(new base::ListValue());
344 for (const auto& key_id
: key_ids
)
345 list
->AppendString(EncodeBase64Url(&key_id
[0], key_id
.size()));
346 request
->Set(kKeyIdsTag
, list
.release());
348 switch (session_type
) {
349 case MediaKeys::TEMPORARY_SESSION
:
350 request
->SetString(kTypeTag
, kTemporarySession
);
352 case MediaKeys::PERSISTENT_LICENSE_SESSION
:
353 request
->SetString(kTypeTag
, kPersistentLicenseSession
);
355 case MediaKeys::PERSISTENT_RELEASE_MESSAGE_SESSION
:
356 request
->SetString(kTypeTag
, kPersistentReleaseMessageSession
);
360 // Serialize the license request as a string.
362 JSONStringValueSerializer
serializer(&json
);
363 serializer
.Serialize(*request
);
365 // Convert the serialized license request into std::vector and return it.
366 std::vector
<uint8
> result(json
.begin(), json
.end());
367 license
->swap(result
);
370 void CreateKeyIdsInitData(const KeyIdList
& key_ids
,
371 std::vector
<uint8
>* init_data
) {
372 // Create the init_data.
373 scoped_ptr
<base::DictionaryValue
> dictionary(new base::DictionaryValue());
374 scoped_ptr
<base::ListValue
> list(new base::ListValue());
375 for (const auto& key_id
: key_ids
)
376 list
->AppendString(EncodeBase64Url(&key_id
[0], key_id
.size()));
377 dictionary
->Set(kKeyIdsTag
, list
.release());
379 // Serialize the dictionary as a string.
381 JSONStringValueSerializer
serializer(&json
);
382 serializer
.Serialize(*dictionary
);
384 // Convert the serialized data into std::vector and return it.
385 std::vector
<uint8
> result(json
.begin(), json
.end());
386 init_data
->swap(result
);
389 bool ExtractFirstKeyIdFromLicenseRequest(const std::vector
<uint8
>& license
,
390 std::vector
<uint8
>* first_key
) {
391 const std::string
license_as_str(
392 reinterpret_cast<const char*>(!license
.empty() ? &license
[0] : NULL
),
394 if (!base::IsStringASCII(license_as_str
)) {
395 DVLOG(1) << "Non ASCII license: " << license_as_str
;
399 scoped_ptr
<base::Value
> root(base::JSONReader().ReadToValue(license_as_str
));
400 if (!root
.get() || root
->GetType() != base::Value::TYPE_DICTIONARY
) {
401 DVLOG(1) << "Not valid JSON: " << license_as_str
;
405 // Locate the set from the dictionary.
406 base::DictionaryValue
* dictionary
=
407 static_cast<base::DictionaryValue
*>(root
.get());
408 base::ListValue
* list_val
= NULL
;
409 if (!dictionary
->GetList(kKeyIdsTag
, &list_val
)) {
410 DVLOG(1) << "Missing '" << kKeyIdsTag
<< "' parameter or not a list";
414 // Get the first key.
415 if (list_val
->GetSize() < 1) {
416 DVLOG(1) << "Empty '" << kKeyIdsTag
<< "' list";
420 std::string encoded_key
;
421 if (!list_val
->GetString(0, &encoded_key
)) {
422 DVLOG(1) << "First entry in '" << kKeyIdsTag
<< "' not a string";
426 std::string decoded_string
= DecodeBase64Url(encoded_key
);
427 if (decoded_string
.empty()) {
428 DVLOG(1) << "Invalid '" << kKeyIdsTag
<< "' value: " << encoded_key
;
432 std::vector
<uint8
> result(decoded_string
.begin(), decoded_string
.end());
433 first_key
->swap(result
);