[ie/youtube] Add age-gate workaround for some embeddable videos (#11821)
[yt-dlp.git] / yt_dlp / utils / traversal.py
blob76b51f53d1a1bc1f27a59882b389fd3fce8502b3
1 from __future__ import annotations
3 import collections
4 import collections.abc
5 import contextlib
6 import functools
7 import http.cookies
8 import inspect
9 import itertools
10 import re
11 import typing
12 import xml.etree.ElementTree
14 from ._utils import (
15 IDENTITY,
16 NO_DEFAULT,
17 ExtractorError,
18 LazyList,
19 deprecation_warning,
20 get_elements_html_by_class,
21 get_elements_html_by_attribute,
22 get_elements_by_attribute,
23 get_element_by_class,
24 get_element_html_by_attribute,
25 get_element_by_attribute,
26 get_element_html_by_id,
27 get_element_by_id,
28 get_element_html_by_class,
29 get_elements_by_class,
30 get_element_text_and_html_by_tag,
31 is_iterable_like,
32 try_call,
33 url_or_none,
34 variadic,
def traverse_obj(
        obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
        casesense=True, is_user_input=NO_DEFAULT, traverse_string=False):
    """
    Safely traverse nested `dict`s and `Iterable`s

    >>> obj = [{}, {"key": "value"}]
    >>> traverse_obj(obj, (1, "key"))
    'value'

    Each of the provided `paths` is tested and the first producing a valid result will be returned.
    The next path will also be tested if the path branched but no results could be found.
    Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
    `xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`.
    Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.

    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.

    The keys in the path can be one of:
        - `None`:           Return the current object.
        - `set`:            Requires the only item in the set to be a type or function,
                            like `{type}`/`{type, type, ...}`/`{func}`. If a `type`, return only
                            values of this type. If a function, returns `func(obj)`.
        - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
        - `slice`:          Branch out and return all values in `obj[key]`.
        - `Ellipsis`:       Branch out and return a list of all values.
        - `tuple`/`list`:   Branch out and return a list of all matching values.
                            Read as: `[traverse_obj(obj, branch) for branch in branches]`.
        - `function`:       Branch out and return values filtered by the function.
                            Read as: `[value for key, value in obj if function(key, value)]`.
                            For `Iterable`s, `key` is the index of the value.
                            For `re.Match`es, `key` is the group number (0 = full match)
                            as well as additionally any group names, if given.
        - `dict`:           Transform the current object and return a matching dict.
                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
        - `any`-builtin:    Take the first matching object and return it, resetting branching.
        - `all`-builtin:    Take all matching objects and return them as a list, resetting branching.
        - `filter`-builtin: Return the value if it is truthy, `None` otherwise.

        `tuple`, `list`, and `dict` all support nested paths and branches.

    @param paths            Paths by which to traverse.
    @param default          Value to return if the paths do not match.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, depth first. Try to avoid if using nested `dict` keys.
    @param expected_type    If a `type`, only accept final values of this type.
                            If any other callable, try to call the function on each result.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, recursively. This does respect branching paths.
    @param get_all          If `False`, return the first matching result, otherwise all matching ones.
    @param casesense        If `False`, consider string dictionary keys as case insensitive.

    `traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API

    @param traverse_string  Whether to traverse into objects as strings.
                            If `True`, any non-compatible object will first be
                            converted into a string and then traversed into.
                            The return value of that path will be a string instead,
                            not respecting any further branching.

    @returns                The result of the object traversal.
                            If successful, `get_all=True`, and the path branches at least once,
                            then a list of results is returned instead.
                            If no `default` is given and the last path branches, a `list` of results
                            is always returned. If a path ends on a `dict` that result will always be a `dict`.
    """
    if is_user_input is not NO_DEFAULT:
        deprecation_warning('The is_user_input parameter is deprecated and no longer works')

    # Key normaliser for case-insensitive lookups; non-string keys are left untouched
    casefold = lambda k: k.casefold() if isinstance(k, str) else k

    # Final-value check applied by `apply_path` when `test_type` is set:
    # a `type` filters by isinstance; any other callable (IDENTITY when None) is applied via try_call
    if isinstance(expected_type, type):
        type_test = lambda val: val if isinstance(val, expected_type) else None
    else:
        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))

    def apply_key(key, obj, is_last):
        """
        Apply a single path `key` to `obj`.

        Returns `(branching, results)`: if `branching` is True, `results` is an
        iterable of values, otherwise it is a 1-tuple wrapping the single result.
        """
        branching = False
        result = None

        if obj is None and traverse_string:
            # Branching keys on a missing value yield no results rather than a single `None`
            if key is ... or callable(key) or isinstance(key, slice):
                branching = True
                result = ()

        elif key is None:
            result = obj

        elif isinstance(key, set):
            item = next(iter(key))
            if len(key) > 1 or isinstance(item, type):
                # `{type, ...}`: keep `obj` only if it is an instance of one of the types
                assert all(isinstance(item, type) for item in key)
                if isinstance(obj, tuple(key)):
                    result = obj
            else:
                # `{func}`: transform `obj` by calling `func(obj)` through try_call
                result = try_call(item, args=(obj,))

        elif isinstance(key, (list, tuple)):
            # Branch: concatenate the results of every sub-path
            branching = True
            result = itertools.chain.from_iterable(
                apply_path(obj, branch, is_last)[0] for branch in key)

        elif key is ...:
            branching = True
            # A Morsel is viewed as a dict that additionally exposes its `key` and `value`
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                result = obj.values()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                result = obj
            elif isinstance(obj, re.Match):
                result = obj.groups()
            elif traverse_string:
                branching = False
                result = str(obj)
            else:
                result = ()

        elif callable(key):
            branching = True
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                iter_obj = obj.items()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                iter_obj = enumerate(obj)
            elif isinstance(obj, re.Match):
                # Numbered groups (0 = whole match) followed by any named groups
                iter_obj = itertools.chain(
                    enumerate((obj.group(), *obj.groups())),
                    obj.groupdict().items())
            elif traverse_string:
                branching = False
                iter_obj = enumerate(str(obj))
            else:
                iter_obj = ()

            # Keep the values for which `key(k, v)` is truthy
            result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
            if not branching:  # string traversal
                result = ''.join(result)

        elif isinstance(key, dict):
            iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
            # Missing values become `default` when one was given, otherwise the entry is dropped;
            # an empty result dict collapses to None
            result = {
                k: v if v is not None else default for k, v in iter_obj
                if v is not None or default is not NO_DEFAULT
            } or None

        elif isinstance(obj, collections.abc.Mapping):
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            # With casesense=False, `key` was already casefolded by apply_path; an exact
            # hit is still preferred before falling back to a case-insensitive scan
            result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
                      next((v for k, v in obj.items() if casefold(k) == key), None))

        elif isinstance(obj, re.Match):
            if isinstance(key, int) or casesense:
                with contextlib.suppress(IndexError):
                    result = obj.group(key)

            elif isinstance(key, str):
                result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)

        elif isinstance(key, (int, slice)):
            if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
                branching = isinstance(key, slice)
                with contextlib.suppress(IndexError):
                    result = obj[key]
            elif traverse_string:
                with contextlib.suppress(IndexError):
                    result = str(obj)[key]

        elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
            # Split an optional trailing `@`, `@attr` or `...()` selector off the xpath
            xpath, _, special = key.rpartition('/')
            if not special.startswith('@') and not special.endswith('()'):
                xpath = key
                special = None

            # Allow abbreviations of relative paths, absolute paths error
            if xpath.startswith('/'):
                xpath = f'.{xpath}'
            elif xpath and not xpath.startswith('./'):
                xpath = f'./{xpath}'

            def apply_specials(element):
                # Resolve the trailing selector against a matched element
                if special is None:
                    return element
                if special == '@':
                    return element.attrib
                if special.startswith('@'):
                    return try_call(element.attrib.get, args=(special[1:],))
                if special == 'text()':
                    return element.text
                raise SyntaxError(f'apply_specials is missing case for {special!r}')

            if xpath:
                result = list(map(apply_specials, obj.iterfind(xpath)))
            else:
                result = apply_specials(obj)

        return branching, result if branching else (result,)

    def lazy_last(iterable):
        # Yield `(is_last, item)` pairs, flagging the final item without materializing the iterable
        iterator = iter(iterable)
        prev = next(iterator, NO_DEFAULT)
        if prev is NO_DEFAULT:
            return

        for item in iterator:
            yield False, prev
            prev = item

        yield True, prev

    def apply_path(start_obj, path, test_type):
        """
        Walk `path` starting from `start_obj`.

        Returns `(objs, has_branched, ends_in_dict)` where `objs` is a (possibly lazy) iterable.
        """
        objs = (start_obj,)
        has_branched = False

        key = None
        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
            if not casesense and isinstance(key, str):
                key = key.casefold()

            if key in (any, all):
                # Collapse branches: `any` keeps the first non-empty value,
                # `all` gathers the non-empty values into a single list
                has_branched = False
                filtered_objs = (obj for obj in objs if obj not in (None, {}))
                if key is any:
                    objs = (next(filtered_objs, None),)
                else:
                    objs = (list(filtered_objs),)
                continue

            if key is filter:
                # Drop falsy values
                objs = filter(None, objs)
                continue

            if __debug__ and callable(key):
                # Verify function signature: key functions are called as `key(k, v)`
                inspect.signature(key).bind(None, None)

            new_objs = []
            for obj in objs:
                branching, results = apply_key(key, obj, last)
                has_branched |= branching
                new_objs.append(results)

            objs = itertools.chain.from_iterable(new_objs)

        # The expected_type test is not applied when the path ends in a dict/list/tuple key,
        # since those delegate to nested traversals
        if test_type and not isinstance(key, (dict, list, tuple)):
            objs = map(type_test, objs)

        return objs, has_branched, isinstance(key, dict)

    def _traverse_obj(obj, path, allow_empty, test_type):
        # `allow_empty` permits returning `[]`/`default` (branched) or `{}` (dict path)
        # instead of None; it is True only for the last of the top-level `paths`
        results, has_branched, is_dict = apply_path(obj, path, test_type)
        results = LazyList(item for item in results if item not in (None, {}))
        if get_all and has_branched:
            if results:
                return results.exhaust()
            if allow_empty:
                return [] if default is NO_DEFAULT else default
            return None

        return results[0] if results else {} if allow_empty and is_dict else None

    for index, path in enumerate(paths, 1):
        is_last = index == len(paths)
        try:
            result = _traverse_obj(obj, path, is_last, True)
            if result is not None:
                return result
        except _RequiredError as e:
            # A failed `require()` only aborts the current path unless it is the last one
            if is_last:
                # Reraise to get cleaner stack trace
                raise ExtractorError(e.orig_msg, expected=e.expected) from None

    return None if default is NO_DEFAULT else default
def value(value, /):
    """
    Return a callable that ignores its single argument and always yields `value`.

    Handy in a traversal path (e.g. `{value(x)}`) to substitute a constant.
    """
    def constant(_):
        return value

    return constant
def require(name, /, *, expected=False):
    """
    Build a callable that passes values through unchanged but rejects `None`.

    When the value is `None`, a `_RequiredError` with the message
    "Unable to extract <name>" is raised, forwarding `expected`.
    `traverse_obj` treats it as a failure of the current path and, on the
    last path, re-raises it as an `ExtractorError`.
    """
    def check(value):
        if value is not None:
            return value
        raise _RequiredError(f'Unable to extract {name}', expected=expected)

    return check
class _RequiredError(ExtractorError):
    """
    Raised by the callable built by `require()` when a required value is `None`.

    `traverse_obj` catches it per path and, on the last path, re-raises it as a
    plain `ExtractorError` carrying the same message and `expected` flag.
    """
@typing.overload
def subs_list_to_dict(*, lang: str | None = 'und', ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ...


@typing.overload
def subs_list_to_dict(subs: list[dict] | None, /, *, lang: str | None = 'und', ext: str | None = None) -> dict[str, list[dict]]: ...


def subs_list_to_dict(subs: list[dict] | None = None, /, *, lang='und', ext=None):
    """
    Convert subtitles from a traversal into a subtitle dict.
    The path should have an `all` immediately before this function.

    Arguments:
    `lang`     The subtitle id used for items without a string `id`;
               if falsy, such items are dropped instead
    `ext`      The default value for `ext` in the subtitle dict, used when an
               item's own `ext` is not a string; if falsy, that `ext` is removed

    In the dict you can set the following additional items:
    `id`       The subtitle id to sort the dict into
    `quality`  The sort order for each subtitle (ascending; missing/falsy = 0)

    Items with neither a valid `url` nor `data` are skipped.
    The input dicts are not modified: shallow copies are stored in the result.

    When called without `subs`, returns a `functools.partial` with `lang` and
    `ext` bound, for use as a traversal function.
    """
    if subs is None:
        return functools.partial(subs_list_to_dict, lang=lang, ext=ext)

    result = collections.defaultdict(list)

    for sub in subs:
        if not url_or_none(sub.get('url')) and not sub.get('data'):
            continue
        # Work on a copy: popping `id`/`quality` and rewriting `ext` must not leak into
        # the caller's source data, nor break a dict that appears more than once
        sub = dict(sub)
        sub_id = sub.pop('id', None)
        if not isinstance(sub_id, str):
            if not lang:
                continue
            sub_id = lang
        sub_ext = sub.get('ext')
        if not isinstance(sub_ext, str):
            if not ext:
                sub.pop('ext', None)
            else:
                sub['ext'] = ext
        result[sub_id].append(sub)
    result = dict(result)

    for sub_list in result.values():
        # `quality` is only a sort hint, so it is removed from the output items
        sub_list.sort(key=lambda x: x.pop('quality', 0) or 0)

    return result
@typing.overload
def find_element(*, attr: str, value: str, tag: str | None = None, html=False, regex=False): ...


@typing.overload
def find_element(*, cls: str, html=False): ...


@typing.overload
def find_element(*, id: str, tag: str | None = None, html=False, regex=False): ...


@typing.overload
def find_element(*, tag: str, html=False, regex=False): ...


def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False, regex=False):
    """
    Build a one-argument callable that extracts the first matching element from an HTML string.

    Exactly one matching mode is used, checked in this order:
    `attr` + `value`, then `cls`, then `id`, and finally `tag` alone.
    With `html=True` the `*_html_*` helper variant is used (for `tag` alone,
    index 1 of `get_element_text_and_html_by_tag`'s result instead of index 0).
    For `attr`/`id`, `regex=True` disables escaping of the value.
    NOTE(review): in the `tag`-only mode `regex` is silently ignored.
    """
    # deliberately using `id=` and `cls=` for ease of readability
    assert tag or id or cls or (attr and value), 'One of tag, id, cls or (attr AND value) is required'
    any_tag = r'[\w:.-]+'

    if attr and value:
        assert not cls, 'Cannot match both attr and cls'
        assert not id, 'Cannot match both attr and id'
        getter = get_element_html_by_attribute if html else get_element_by_attribute
        return functools.partial(getter, attr, value, tag=tag or any_tag, escape_value=not regex)

    if cls:
        assert not id, 'Cannot match both cls and id'
        assert tag is None, 'Cannot match both cls and tag'
        assert not regex, 'Cannot use regex with cls'
        getter = get_element_html_by_class if html else get_element_by_class
        return functools.partial(getter, cls)

    if id:
        getter = get_element_html_by_id if html else get_element_by_id
        return functools.partial(getter, id, tag=tag or any_tag, escape_value=not regex)

    # Only `tag` remains: select the text (0) or the whole element html (1)
    want_html = int(bool(html))

    def by_tag(document):
        return get_element_text_and_html_by_tag(tag, document)[want_html]

    return by_tag
@typing.overload
def find_elements(*, cls: str, html=False): ...


@typing.overload
def find_elements(*, attr: str, value: str, tag: str | None = None, html=False, regex=False): ...


def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False, regex=False):
    """
    Build a one-argument callable that extracts all matching elements from an HTML string.

    Matches either by `attr` + `value` (optionally restricted to `tag`, with
    `regex=True` disabling escaping of the value) or by `cls`.
    With `html=True` the `*_html_*` helper variant is used.
    """
    # deliberately using `cls=` for ease of readability
    assert cls or (attr and value), 'One of cls or (attr AND value) is required'

    if not (attr and value):
        # Class-based matching
        assert not tag, 'Cannot match both cls and tag'
        assert not regex, 'Cannot use regex with cls'
        getter = get_elements_html_by_class if html else get_elements_by_class
        return functools.partial(getter, cls)

    assert not cls, 'Cannot match both attr and cls'
    getter = get_elements_html_by_attribute if html else get_elements_by_attribute
    return functools.partial(getter, attr, value, tag=tag or r'[\w:.-]+', escape_value=not regex)
def trim_str(*, start=None, end=None):
    """
    Build a callable that strips an optional prefix and/or suffix from a string.

    The prefix is removed only when `start` is non-empty and the string begins
    with it; likewise `end` at the tail. `None` is passed through unchanged.
    """
    def trim(s):
        if s is None:
            return None
        begin = len(start) if start and s.startswith(start) else 0
        if not (end and s.endswith(end)):
            return s[begin:]
        return s[begin:-len(end)]

    return trim
def unpack(func, **kwargs):
    """
    Adapt `func` to take one iterable whose items become its positional arguments.

    Any `kwargs` given here are forwarded on every call. The returned wrapper
    carries `func`'s metadata via `functools.wraps`.
    """
    def inner(items):
        return func(*items, **kwargs)

    return functools.wraps(func)(inner)
def get_first(obj, *paths, **kwargs):
    """
    Apply `paths` to every item of `obj` and return the first match.

    Each path is prefixed with `...` (branch over all values of `obj`) and
    `traverse_obj` is called with `get_all=False`.
    """
    branched_paths = [(..., *variadic(keys)) for keys in paths]
    return traverse_obj(obj, *branched_paths, **kwargs, get_all=False)
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
    """
    Return the first usable value of `d` among one or more keys.

    Keys are tried in order. `None` values are always skipped; other falsy
    values are also skipped unless `skip_false_values` is `False`.
    Returns `default` when no key yields a usable value.
    """
    for key in variadic(key_or_keys):
        candidate = d.get(key)
        if candidate is None:
            continue
        if not candidate and skip_false_values:
            continue
        return candidate
    return default