[utils] Fix `find_element` by class (#11402)
[yt-dlp3.git] / yt_dlp / utils / traversal.py
blob0eef817eaacee0c31917a9233897a7efd80d5e2f
from __future__ import annotations

import collections
import collections.abc
import contextlib
import functools
import http.cookies
import inspect
import itertools
import re
import typing
import xml.etree.ElementTree

from ._utils import (
    IDENTITY,
    NO_DEFAULT,
    ExtractorError,
    LazyList,
    deprecation_warning,
    get_element_by_attribute,
    get_element_by_class,
    get_element_by_id,
    get_element_html_by_attribute,
    get_element_html_by_class,
    get_element_html_by_id,
    get_element_text_and_html_by_tag,
    get_elements_by_attribute,
    get_elements_by_class,
    get_elements_html_by_attribute,
    get_elements_html_by_class,
    is_iterable_like,
    try_call,
    url_or_none,
    variadic,
)
def traverse_obj(
        obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
        casesense=True, is_user_input=NO_DEFAULT, traverse_string=False):
    """
    Safely traverse nested `dict`s and `Iterable`s

    >>> obj = [{}, {"key": "value"}]
    >>> traverse_obj(obj, (1, "key"))
    'value'

    Each of the provided `paths` is tested and the first producing a valid result will be returned.
    The next path will also be tested if the path branched but no results could be found.
    Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
    `xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`.
    Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.

    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.

    The keys in the path can be one of:
        - `None`:           Return the current object.
        - `set`:            Requires the only item in the set to be a type or function,
                            like `{type}`/`{type, type, ...}`/`{func}`. If a `type`, return only
                            values of this type. If a function, returns `func(obj)`.
        - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
                            For an `xml.etree.ElementTree.Element`, a `str` key is an xpath
                            (abbreviated relative paths are allowed). It returns a list with
                            the result for every matching element, without branching. A final
                            `@attr`, `@` or `text()` step returns that attribute, the attribute
                            dict or the element text instead of the element itself.
        - `slice`:          Branch out and return all values in `obj[key]`.
        - `Ellipsis`:       Branch out and return a list of all values.
        - `tuple`/`list`:   Branch out and return a list of all matching values.
                            Read as: `[traverse_obj(obj, branch) for branch in branches]`.
        - `function`:       Branch out and return values filtered by the function.
                            Read as: `[value for key, value in obj if function(key, value)]`.
                            For `Iterable`s, `key` is the index of the value.
                            For `re.Match`es, `key` is the group number (0 = full match)
                            as well as additionally any group names, if given.
        - `dict`:           Transform the current object and return a matching dict.
                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
        - `any`-builtin:    Take the first matching object and return it, resetting branching.
        - `all`-builtin:    Take all matching objects and return them as a list, resetting branching.
        - `filter`-builtin: Keep only truthy values; falsy values are discarded like `None`.

        `tuple`, `list`, and `dict` all support nested paths and branches.

    @param paths            Paths by which to traverse.
    @param default          Value to return if the paths do not match.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, depth first. Try to avoid if using nested `dict` keys.
    @param expected_type    If a `type`, only accept final values of this type.
                            If any other callable, try to call the function on each result.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, recursively. This does respect branching paths.
    @param get_all          If `False`, return the first matching result, otherwise all matching ones.
    @param casesense        If `False`, consider string dictionary keys as case insensitive.

    `traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API

    @param traverse_string  Whether to traverse into objects as strings.
                            If `True`, any non-compatible object will first be
                            converted into a string and then traversed into.
                            The return value of that path will be a string instead,
                            not respecting any further branching.


    @returns                The result of the object traversal.
                            If successful, `get_all=True`, and the path branches at least once,
                            then a list of results is returned instead.
                            If no `default` is given and the last path branches, a `list` of results
                            is always returned. If a path ends on a `dict` that result will always be a `dict`.
    @raises ExtractorError  If a `require(...)` key fails on the last path.
    """
    if is_user_input is not NO_DEFAULT:
        deprecation_warning('The is_user_input parameter is deprecated and no longer works')

    # With `casesense=False` the path's string keys are casefolded once in `apply_path`;
    # this helper casefolds the object's own keys so both sides can be compared
    casefold = lambda k: k.casefold() if isinstance(k, str) else k

    if isinstance(expected_type, type):
        type_test = lambda val: val if isinstance(val, expected_type) else None
    else:
        # Any other callable is applied through `try_call` so that a failing conversion
        # does not abort the traversal (see `._utils.try_call` for what it swallows)
        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))

    def apply_key(key, obj, is_last):
        # Apply a single path `key` to `obj`.
        # Returns `(branching, results)`, where `results` is iterable. A non-branching
        # result is wrapped in a 1-tuple at the end so callers can always iterate it.
        branching = False
        result = None

        if obj is None and traverse_string:
            # Branching keys on a missing value produce no results rather than `None`
            if key is ... or callable(key) or isinstance(key, slice):
                branching = True
                result = ()

        elif key is None:
            result = obj

        elif isinstance(key, set):
            item = next(iter(key))
            if len(key) > 1 or isinstance(item, type):
                # Type filter: keep `obj` only if it is an instance of one of the types
                assert all(isinstance(item, type) for item in key)
                if isinstance(obj, tuple(key)):
                    result = obj
            else:
                # Single function: transform `obj`
                result = try_call(item, args=(obj,))

        elif isinstance(key, (list, tuple)):
            # Every element is a sub-path; the results of all branches are concatenated
            branching = True
            result = itertools.chain.from_iterable(
                apply_path(obj, branch, is_last)[0] for branch in key)

        elif key is ...:
            branching = True
            if isinstance(obj, http.cookies.Morsel):
                # Expose the cookie's name and value alongside its attributes
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                result = obj.values()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                result = obj
            elif isinstance(obj, re.Match):
                result = obj.groups()
            elif traverse_string:
                branching = False
                result = str(obj)
            else:
                result = ()

        elif callable(key):
            branching = True
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                iter_obj = obj.items()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                iter_obj = enumerate(obj)
            elif isinstance(obj, re.Match):
                # Keys are the group numbers (0 = full match) followed by the group names
                iter_obj = itertools.chain(
                    enumerate((obj.group(), *obj.groups())),
                    obj.groupdict().items())
            elif traverse_string:
                branching = False
                iter_obj = enumerate(str(obj))
            else:
                iter_obj = ()

            result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
            if not branching:  # string traversal
                result = ''.join(result)

        elif isinstance(key, dict):
            # Each value is traversed as its own path. `expected_type` applies to it only
            # when this dict is the last key (`test_type=is_last`). `default` fills in misses
            # only when it was explicitly given; otherwise unmatched keys are left out.
            iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
            result = {
                k: v if v is not None else default for k, v in iter_obj
                if v is not None or default is not NO_DEFAULT
            } or None

        elif isinstance(obj, collections.abc.Mapping):
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            # Use a direct lookup when case-sensitive or the (casefolded) key exists as-is;
            # otherwise scan for a key that matches case-insensitively
            result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
                      next((v for k, v in obj.items() if casefold(k) == key), None))

        elif isinstance(obj, re.Match):
            if isinstance(key, int) or casesense:
                # An unknown group number/name leaves `result` as `None`
                with contextlib.suppress(IndexError):
                    result = obj.group(key)

            elif isinstance(key, str):
                result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)

        elif isinstance(key, (int, slice)):
            if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
                branching = isinstance(key, slice)
                with contextlib.suppress(IndexError):
                    result = obj[key]
            elif traverse_string:
                with contextlib.suppress(IndexError):
                    result = str(obj)[key]

        elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
            # Split off a trailing `@attr`, `@` or `...()` step; anything else is a plain xpath
            xpath, _, special = key.rpartition('/')
            if not special.startswith('@') and not special.endswith('()'):
                xpath = key
                special = None

            # Allow abbreviations of relative paths ('a' -> './a'). ElementTree rejects
            # absolute paths on an Element, so a leading '/' is made relative ('/a' -> './a')
            if xpath.startswith('/'):
                xpath = f'.{xpath}'
            elif xpath and not xpath.startswith('./'):
                xpath = f'./{xpath}'

            def apply_specials(element):
                if special is None:
                    return element
                if special == '@':
                    return element.attrib
                if special.startswith('@'):
                    return try_call(element.attrib.get, args=(special[1:],))
                if special == 'text()':
                    return element.text
                raise SyntaxError(f'apply_specials is missing case for {special!r}')

            if xpath:
                result = list(map(apply_specials, obj.iterfind(xpath)))
            else:
                result = apply_specials(obj)

        return branching, result if branching else (result,)

    def lazy_last(iterable):
        # Yield `(is_last, item)` pairs, flagging only the final item with `True`
        iterator = iter(iterable)
        prev = next(iterator, NO_DEFAULT)
        if prev is NO_DEFAULT:
            return

        for item in iterator:
            yield False, prev
            prev = item

        yield True, prev

    def apply_path(start_obj, path, test_type):
        # Apply every key of `path` in turn to all current objects.
        # Returns `(objs, has_branched, ends_with_dict)`; `objs` is a lazy iterable.
        objs = (start_obj,)
        has_branched = False

        key = None
        # `str`/`bytes`/`dict`/`set` paths are passed to `variadic` as single keys
        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
            if not casesense and isinstance(key, str):
                key = key.casefold()

            if key in (any, all):
                # Collapse all current branches into a single value
                has_branched = False
                filtered_objs = (obj for obj in objs if obj not in (None, {}))
                if key is any:
                    objs = (next(filtered_objs, None),)
                else:
                    objs = (list(filtered_objs),)
                continue

            if key is filter:
                objs = filter(None, objs)
                continue

            if __debug__ and callable(key):
                # Verify function signature: filter functions must accept `(key, value)`
                inspect.signature(key).bind(None, None)

            new_objs = []
            for obj in objs:
                branching, results = apply_key(key, obj, last)
                has_branched |= branching
                new_objs.append(results)

            objs = itertools.chain.from_iterable(new_objs)

        # `dict`/`list`/`tuple` keys have already applied `type_test` in their sub-traversals
        if test_type and not isinstance(key, (dict, list, tuple)):
            objs = map(type_test, objs)

        return objs, has_branched, isinstance(key, dict)

    def _traverse_obj(obj, path, allow_empty, test_type):
        # Traverse a single path and shape its result according to `get_all`/branching.
        # `allow_empty` lets an unmatched path return `[]`/`default`/`{}` instead of `None`.
        results, has_branched, is_dict = apply_path(obj, path, test_type)
        results = LazyList(item for item in results if item not in (None, {}))
        if get_all and has_branched:
            if results:
                return results.exhaust()
            if allow_empty:
                return [] if default is NO_DEFAULT else default
            return None

        return results[0] if results else {} if allow_empty and is_dict else None

    for index, path in enumerate(paths, 1):
        # Only the last path may produce an "empty" result, and only its failed
        # `require(...)` is raised; earlier paths fall through to the next one
        is_last = index == len(paths)
        try:
            result = _traverse_obj(obj, path, is_last, True)
            if result is not None:
                return result
        except _RequiredError as e:
            if is_last:
                # Reraise to get cleaner stack trace
                raise ExtractorError(e.orig_msg, expected=e.expected) from None

    return None if default is NO_DEFAULT else default
def value(value, /):
    """Build a traversal function that discards its input and yields `value`."""
    def constant(_ignored):
        return value

    return constant
def require(name, /, *, expected=False):
    """
    Build a traversal function that passes values through unchanged and
    raises `_RequiredError` when the value is `None`.

    @param name      Name used in the error message ("Unable to extract <name>")
    @param expected  Forwarded as the `expected` flag of the raised error
    """
    def checker(value):
        if value is not None:
            return value
        raise _RequiredError(f'Unable to extract {name}', expected=expected)

    return checker
class _RequiredError(ExtractorError):
    """Raised by `require()` functions; `traverse_obj` converts it for the last path."""
@typing.overload
def subs_list_to_dict(*, ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ...


@typing.overload
def subs_list_to_dict(subs: list[dict] | None, /, *, ext: str | None = None) -> dict[str, list[dict]]: ...


def subs_list_to_dict(subs: list[dict] | None = None, /, *, ext=None):
    """
    Turn a list of subtitle dicts produced by a traversal into a subtitles dict.
    Put an `all` right before this function in the path so that it receives
    every matched subtitle as one list.

    Arguments:
    `ext`      Fallback `ext` for subtitles with no (or an empty) `ext`

    Extra keys read from each subtitle dict (removed from kept subtitles):
    `id`       Key of the result dict under which the subtitle is grouped
    `quality`  Ascending sort key of the subtitle within its group

    Subtitles without a usable `url` (per `url_or_none`) or `data`, or without an `id`, are dropped.
    """
    # Called without subtitles (e.g. `subs_list_to_dict(ext='vtt')` inside a path):
    # hand back a pre-configured version of this function
    if subs is None:
        return functools.partial(subs_list_to_dict, ext=ext)

    grouped = {}
    for entry in subs:
        # Only subtitles with a valid URL or inline data are usable
        if not (url_or_none(entry.get('url')) or entry.get('data')):
            continue
        # `id` is moved out of the entry and becomes its grouping key
        sub_id = entry.pop('id', None)
        if sub_id is None:
            continue
        if ext is not None and not entry.get('ext'):
            entry['ext'] = ext
        grouped.setdefault(sub_id, []).append(entry)

    def pop_quality(entry):
        # `quality` is consumed while sorting; a missing or falsy value sorts as 0
        return entry.pop('quality', 0) or 0

    for entries in grouped.values():
        entries.sort(key=pop_quality)

    return grouped
@typing.overload
def find_element(*, attr: str, value: str, tag: str | None = None, html=False): ...


@typing.overload
def find_element(*, cls: str, html=False): ...


@typing.overload
def find_element(*, id: str, tag: str | None = None, html=False): ...


@typing.overload
def find_element(*, tag: str, html=False): ...


def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False):
    """
    Create a function that extracts the FIRST matching element from an HTML document.

    Meant to be used as a traversal key, e.g. `{find_element(cls='title')}`.
    Exactly one matching strategy is used, checked in this order:
        `attr` + `value`  Match an attribute value (optionally restricted to `tag`)
        `cls`             Match a CSS class (cannot be combined with `tag` or `id`)
        `id`              Match the `id` attribute (optionally restricted to `tag`)
        `tag`             Match the first element with that tag name

    @param html   If `True`, use the `*_html_*` helper variants (the whole element's HTML);
                  otherwise the content variants.
    @returns      A function taking the HTML document as its only argument.
    """
    # deliberately using `id=` and `cls=` for ease of readability
    assert tag or id or cls or (attr and value), 'One of tag, id, cls or (attr AND value) is required'
    ANY_TAG = r'[\w:.-]+'

    if attr and value:
        assert not cls, 'Cannot match both attr and cls'
        assert not id, 'Cannot match both attr and id'
        func = get_element_html_by_attribute if html else get_element_by_attribute
        return functools.partial(func, attr, value, tag=tag or ANY_TAG)

    elif cls:
        assert not id, 'Cannot match both cls and id'
        assert tag is None, 'Cannot match both cls and tag'
        # Use the singular getter: like every other branch (and like
        # `get_element_html_by_class`), only the first match is returned.
        # The plural `get_elements_by_class` used before returned a list.
        func = get_element_html_by_class if html else get_element_by_class
        return functools.partial(func, cls)

    elif id:
        func = get_element_html_by_id if html else get_element_by_id
        return functools.partial(func, id, tag=tag or ANY_TAG)

    # `get_element_text_and_html_by_tag` yields a (text, html) pair; pick one side
    index = int(bool(html))
    return lambda document: get_element_text_and_html_by_tag(tag, document)[index]
@typing.overload
def find_elements(*, cls: str, html=False): ...


@typing.overload
def find_elements(*, attr: str, value: str, tag: str | None = None, html=False): ...


def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False):
    """
    Create a function that extracts EVERY matching element from an HTML document.

    Match either by an `attr` + `value` pair (optionally restricted to `tag`)
    or by CSS class `cls` (which cannot be combined with `tag`).
    With `html=True`, the `*_html_*` helper variants are used.
    """
    # `cls=` is used on purpose since it reads naturally and avoids the `class` keyword
    assert cls or (attr and value), 'One of cls or (attr AND value) is required'

    if not (attr and value):
        assert not tag, 'Cannot match both cls and tag'
        getter = get_elements_html_by_class if html else get_elements_by_class
        return functools.partial(getter, cls)

    assert not cls, 'Cannot match both attr and cls'
    getter = get_elements_html_by_attribute if html else get_elements_by_attribute
    return functools.partial(getter, attr, value, tag=tag or r'[\w:.-]+')
def get_first(obj, *paths, **kwargs):
    """
    Return the first match of any of `paths` across all items of `obj`.

    Each path is prefixed with `...`, so it is tried on every item. `traverse_obj`
    is then called with `get_all=False`, so only the first result is returned.
    Other keyword arguments are forwarded to `traverse_obj`.
    """
    per_item_paths = [(..., *variadic(keys)) for keys in paths]
    return traverse_obj(obj, *per_item_paths, **kwargs, get_all=False)
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
    """
    Look up one key (or several keys, tried in order) in `d`.

    Return the first value that is not `None`. When `skip_false_values` is set,
    the value must also be truthy. Return `default` if no key yields such a value.
    """
    for key in variadic(key_or_keys):
        candidate = d.get(key)
        if candidate is None:
            continue
        if candidate or not skip_false_values:
            return candidate
    return default