[ie/mlbtv] Fix extractor (#10515)
[yt-dlp.git] / yt_dlp / utils / traversal.py
blob96eb2eddf5296d6af7c374111a0fc672892a3542
1 import collections.abc
2 import contextlib
3 import http.cookies
4 import inspect
5 import itertools
6 import re
7 import xml.etree.ElementTree
9 from ._utils import (
10 IDENTITY,
11 NO_DEFAULT,
12 LazyList,
13 deprecation_warning,
14 is_iterable_like,
15 try_call,
16 variadic,
def traverse_obj(
        obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
        casesense=True, is_user_input=NO_DEFAULT, traverse_string=False):
    """
    Safely traverse nested `dict`s and `Iterable`s

    >>> obj = [{}, {"key": "value"}]
    >>> traverse_obj(obj, (1, "key"))
    'value'

    Each of the provided `paths` is tested and the first producing a valid result will be returned.
    The next path will also be tested if the path branched but no results could be found.
    Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
    `xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`.
    Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.

    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.

    The keys in the path can be one of:
        - `None`:           Return the current object.
        - `set`:            Requires the only item in the set to be a type or function,
                            like `{type}`/`{type, type, ...}`/`{func}`. If a `type`, return only
                            values of this type. If a function, returns `func(obj)`.
        - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
        - `slice`:          Branch out and return all values in `obj[key]`.
        - `Ellipsis`:       Branch out and return a list of all values.
        - `tuple`/`list`:   Branch out and return a list of all matching values.
                            Read as: `[traverse_obj(obj, branch) for branch in branches]`.
        - `function`:       Branch out and return values filtered by the function.
                            Read as: `[value for key, value in obj if function(key, value)]`.
                            For `Iterable`s, `key` is the index of the value.
                            For `re.Match`es, `key` is the group number (0 = full match)
                            as well as additionally any group names, if given.
        - `dict`:           Transform the current object and return a matching dict.
                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
        - `any`-builtin:    Take the first matching object and return it, resetting branching.
        - `all`-builtin:    Take all matching objects and return them as a list, resetting branching.

        `tuple`, `list`, and `dict` all support nested paths and branches.

    @param paths            Paths which to traverse by.
    @param default          Value to return if the paths do not match.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, depth first. Try to avoid if using nested `dict` keys.
    @param expected_type    If a `type`, only accept final values of this type.
                            If any other callable, try to call the function on each result.
                            If the last key in the path is a `dict`, it will apply to each value inside
                            the dict instead, recursively. This does respect branching paths.
    @param get_all          If `False`, return the first matching result, otherwise all matching ones.
    @param casesense        If `False`, consider string dictionary keys as case insensitive.
    @param is_user_input    Deprecated and without effect; passing any value only
                            triggers a deprecation warning.

    `traverse_string` is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API

    @param traverse_string  Whether to traverse into objects as strings.
                            If `True`, any non-compatible object will first be
                            converted into a string and then traversed into.
                            The return value of that path will be a string instead,
                            not respecting any further branching.


    @returns                The result of the object traversal.
                            If successful, `get_all=True`, and the path branches at least once,
                            then a list of results is returned instead.
                            If no `default` is given and the last path branches, a `list` of results
                            is always returned. If a path ends on a `dict` that result will always be a `dict`.
    """
    if is_user_input is not NO_DEFAULT:
        deprecation_warning('The is_user_input parameter is deprecated and no longer works')

    # Casefold string keys only; non-string keys (ints, etc.) are compared as-is
    casefold = lambda k: k.casefold() if isinstance(k, str) else k

    # `type_test` maps every final value to itself or to `None`; `None` results are
    # filtered out later in `_traverse_obj`
    if isinstance(expected_type, type):
        type_test = lambda val: val if isinstance(val, expected_type) else None
    else:
        # Any other callable is invoked on the value (`IDENTITY` if none was given).
        # NOTE(review): `try_call` presumably yields `None` if the call raises - confirm in `_utils`
        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))

    def apply_key(key, obj, is_last):
        """
        Apply a single path `key` to `obj`

        Returns a `(branching, results)` tuple, where `results` is always an iterable:
        the (possibly lazy) collection of values if the key branched,
        otherwise a one-element tuple holding the single result (which may be `None`).
        `is_last` is only forwarded to nested paths (`list`/`tuple`/`dict` keys).
        """
        branching = False
        result = None

        if obj is None and traverse_string:
            # Never stringify `None`: branching keys produce no results,
            # every other key leaves `result` as `None`
            if key is ... or callable(key) or isinstance(key, slice):
                branching = True
                result = ()

        elif key is None:
            result = obj

        elif isinstance(key, set):
            # Only the first item is inspected to decide between type filter and function call
            item = next(iter(key))
            if len(key) > 1 or isinstance(item, type):
                # Type filter: keep `obj` only if it is an instance of any of the given types
                assert all(isinstance(item, type) for item in key)
                if isinstance(obj, tuple(key)):
                    result = obj
            else:
                # Single non-type item: treat it as a function transforming `obj`
                result = try_call(item, args=(obj,))

        elif isinstance(key, (list, tuple)):
            # Each branch is traversed as its own path; the results are chained lazily
            branching = True
            result = itertools.chain.from_iterable(
                apply_path(obj, branch, is_last)[0] for branch in key)

        elif key is ...:
            branching = True
            if isinstance(obj, http.cookies.Morsel):
                # Expose the cookie's own key and value alongside its attributes
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                result = obj.values()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                result = obj
            elif isinstance(obj, re.Match):
                # Note: `groups()` does not include the full match (group 0)
                result = obj.groups()
            elif traverse_string:
                # String traversal yields a single string and does not branch
                branching = False
                result = str(obj)
            else:
                result = ()

        elif callable(key):
            # Filter function called as `key(k, v)` for every key/index and value pair
            branching = True
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)
            if isinstance(obj, collections.abc.Mapping):
                iter_obj = obj.items()
            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                iter_obj = enumerate(obj)
            elif isinstance(obj, re.Match):
                # Numbered groups first (0 = full match), followed by the named groups
                iter_obj = itertools.chain(
                    enumerate((obj.group(), *obj.groups())),
                    obj.groupdict().items())
            elif traverse_string:
                branching = False
                iter_obj = enumerate(str(obj))
            else:
                iter_obj = ()

            result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
            if not branching:  # string traversal
                result = ''.join(result)

        elif isinstance(key, dict):
            # Every value of the `dict` key is a sub-path traversed from the current `obj`
            iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
            # Entries without a result are dropped unless a `default` was provided;
            # a dict that ends up empty is treated as no result at all
            result = {
                k: v if v is not None else default for k, v in iter_obj
                if v is not None or default is not NO_DEFAULT
            } or None

        elif isinstance(obj, collections.abc.Mapping):
            if isinstance(obj, http.cookies.Morsel):
                obj = dict(obj, key=obj.key, value=obj.value)

            # Exact lookup if case sensitive or if the key is contained as-is, otherwise
            # compare against the casefolded mapping keys
            # (`key` itself was already casefolded in `apply_path`)
            result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
                      next((v for k, v in obj.items() if casefold(k) == key), None))

        elif isinstance(obj, re.Match):
            if isinstance(key, int) or casesense:
                # Unknown group numbers/names raise `IndexError`, leaving `result` as `None`
                with contextlib.suppress(IndexError):
                    result = obj.group(key)

            elif isinstance(key, str):
                # Case insensitive lookup of a named group
                result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)

        elif isinstance(key, (int, slice)):
            if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
                # Only slices branch; an out of range index leaves `result` as `None`
                branching = isinstance(key, slice)
                with contextlib.suppress(IndexError):
                    result = obj[key]
            elif traverse_string:
                with contextlib.suppress(IndexError):
                    result = str(obj)[key]

        elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
            # The last path segment may be a "special": `@` (all attributes),
            # `@name` (single attribute) or a function call such as `text()`
            xpath, _, special = key.rpartition('/')
            if not special.startswith('@') and not special.endswith('()'):
                xpath = key
                special = None

            # Allow abbreviations of relative paths, absolute paths error
            if xpath.startswith('/'):
                xpath = f'.{xpath}'
            elif xpath and not xpath.startswith('./'):
                xpath = f'./{xpath}'

            def apply_specials(element):
                """Resolve the trailing special (if any) against a matched `element`"""
                if special is None:
                    return element
                if special == '@':
                    return element.attrib
                if special.startswith('@'):
                    return try_call(element.attrib.get, args=(special[1:],))
                if special == 'text()':
                    return element.text
                # Any other `...()` special is not implemented
                raise SyntaxError(f'apply_specials is missing case for {special!r}')

            if xpath:
                # Note: the list of matches is returned as a single, non-branching result
                result = list(map(apply_specials, obj.iterfind(xpath)))
            else:
                # The key consisted of a special only, so apply it to `obj` itself
                result = apply_specials(obj)

        # Non-branching results are wrapped so that callers can always iterate
        return branching, result if branching else (result,)

    def lazy_last(iterable):
        """
        Yield `(is_last, item)` pairs for `iterable` using a one item lookahead

        Nothing is yielded for an empty iterable. `NO_DEFAULT` doubles as the
        "exhausted" sentinel, so it is not expected to appear as the first item.
        """
        iterator = iter(iterable)
        prev = next(iterator, NO_DEFAULT)
        if prev is NO_DEFAULT:
            return

        for item in iterator:
            yield False, prev
            prev = item

        yield True, prev

    def apply_path(start_obj, path, test_type):
        """
        Apply all keys of `path` in order, starting from `start_obj`

        Returns `(objs, has_branched, is_dict)`: the lazy iterable of results,
        whether the path branched (since the last `any`/`all`, if present)
        and whether the last key of the path was a `dict`.
        `test_type` controls whether `type_test` is applied to the final values.
        """
        objs = (start_obj,)
        has_branched = False

        key = None
        # NOTE(review): `str`/`bytes`/`dict`/`set` are passed to `variadic` presumably so that
        # they count as a single key instead of being iterated - confirm in `_utils`
        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
            if not casesense and isinstance(key, str):
                key = key.casefold()

            if key in (any, all):
                # Collapse the current results into a single object and reset branching
                has_branched = False
                filtered_objs = (obj for obj in objs if obj not in (None, {}))
                if key is any:
                    objs = (next(filtered_objs, None),)
                else:
                    objs = (list(filtered_objs),)
                continue

            if __debug__ and callable(key):
                # Verify function signature
                # (binding fails early if `key` cannot be called with two positional arguments)
                inspect.signature(key).bind(None, None)

            new_objs = []
            for obj in objs:
                branching, results = apply_key(key, obj, last)
                has_branched |= branching
                new_objs.append(results)

            objs = itertools.chain.from_iterable(new_objs)

        # For `dict`/`list`/`tuple` as the last key the type test is not applied here;
        # `test_type` was instead passed down as `is_last` to the nested paths
        if test_type and not isinstance(key, (dict, list, tuple)):
            objs = map(type_test, objs)

        return objs, has_branched, isinstance(key, dict)

    def _traverse_obj(obj, path, allow_empty, test_type):
        """
        Traverse `obj` along a single `path` and shape the final result

        Returns `None` if nothing matched, unless `allow_empty` is set, in which case
        an empty `list` (or `default`) is returned for branching paths
        and an empty `dict` for non-branching paths that end on a `dict` key.
        """
        results, has_branched, is_dict = apply_path(obj, path, test_type)
        # Discard unhelpful values (`None` and `{}`) without consuming the results eagerly
        results = LazyList(item for item in results if item not in (None, {}))
        if get_all and has_branched:
            if results:
                # NOTE(review): `exhaust()` presumably evaluates the whole lazy list - confirm in `_utils`
                return results.exhaust()
            if allow_empty:
                return [] if default is NO_DEFAULT else default
            return None

        return results[0] if results else {} if allow_empty and is_dict else None

    # Try every path in order; only the last one may produce an "empty" result
    for index, path in enumerate(paths, 1):
        result = _traverse_obj(obj, path, index == len(paths), True)
        if result is not None:
            return result

    return None if default is NO_DEFAULT else default
def get_first(obj, *paths, **kwargs):
    """
    Return the first match found when branching over `obj` with each of `paths`

    Every path is prefixed with `...` (branch over all values of `obj`) and then
    handed to `traverse_obj` with `get_all=False`. Any additional keyword
    arguments are forwarded unchanged; `get_all` itself must not be passed.
    """
    branching_paths = []
    for keys in paths:
        # Prepend the branching key to the (variadic-wrapped) user supplied path
        branching_paths.append((..., *variadic(keys)))

    return traverse_obj(obj, *branching_paths, **kwargs, get_all=False)
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
    """
    Return the first usable value of `d` for a key or a sequence of candidate keys

    The keys are looked up in order using `d.get`. Values that are `None` are
    always skipped; falsy values are skipped as well unless `skip_false_values`
    is `False`. If no candidate qualifies, `default` is returned.
    """
    lookup = d.get
    for candidate_key in variadic(key_or_keys):
        value = lookup(candidate_key)
        if value is None:
            continue
        if not value and skip_false_values:
            continue
        return value

    return default