dev: add a pyupgrade check step
[git-cola.git] / cola / utils.py
blob853f56087052dc46689ef6c44dda0e3891f499b6
1 """Miscellaneous utility functions"""
2 import copy
3 import os
4 import re
5 import shlex
6 import sys
7 import tempfile
8 import time
9 import traceback
11 from . import core
12 from . import compat
def asint(obj, default=0):
    """Convert *obj* to an int, returning *default* when the conversion fails.

    Only TypeError and ValueError raised by ``int()`` count as failures.
    """
    try:
        return int(obj)
    except (TypeError, ValueError):
        return default
def clamp(value, low, high):
    """Restrict *value* to the closed range ``[low, high]``."""
    bounded_below = max(low, value)
    return min(high, bounded_below)
def epoch_millis():
    """Return the current Unix time as a whole number of milliseconds."""
    seconds = time.time()
    return int(seconds * 1000)
def add_parents(paths):
    """Return a set of every path in *paths* plus its parent directories.

    Runs of repeated slashes are collapsed before the parents are computed
    with dirname().  Paths without a '/' contribute only themselves.
    """
    result = set()
    for entry in paths:
        while '//' in entry:
            entry = entry.replace('//', '/')
        result.add(entry)
        if '/' not in entry:
            continue
        parent = dirname(entry)
        while parent:
            result.add(parent)
            parent = dirname(parent)
    return result
def format_exception(exc):
    """Format an exception object for display

    :param exc: the exception to describe
    :returns: a ``(msg, details)`` tuple.  ``msg`` is ``exc.msg`` when that
        attribute exists, otherwise the decoded ``repr(exc)``.  ``details``
        is the formatted traceback text, each line decoded with core.decode().
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    if exc_value is not exc and isinstance(exc, BaseException):
        # sys.exc_info() only describes the exception currently being handled.
        # When ``exc`` is not that exception (e.g. it is formatted after its
        # ``except`` block finished, or while another exception is in flight),
        # describe ``exc`` itself instead of an unrelated or empty state.
        exc_type = type(exc)
        exc_value = exc
        exc_tb = getattr(exc, '__traceback__', None)
    details = traceback.format_exception(exc_type, exc_value, exc_tb)
    details = '\n'.join(map(core.decode, details))
    if hasattr(exc, 'msg'):
        msg = exc.msg
    else:
        msg = core.decode(repr(exc))
    return (msg, details)
def sublist(values, remove):
    """Return the items of *values* that do not appear in *remove*.

    Conceptually this computes ``values - remove`` for lists: the order and
    duplicates of *values* are preserved.  Membership is tested with ``in``,
    so *remove* may be any container (a set is fastest for large inputs).
    """
    return [item for item in values if item not in remove]
__grep_cache = {}


def grep(pattern, items, squash=True):
    """Greps a list for items that match a pattern

    Each pattern is compiled once and memoized in a module-level cache.

    :param squash: If only one item matches, return just that item
    :returns: For a dict, a dict of the matching key/value pairs.  Otherwise
        a list holding, for each match, the whole match (pattern without
        groups), the single group, or a list of all groups.
    """
    try:
        regex = __grep_cache[pattern]
    except KeyError:
        regex = __grep_cache[pattern] = re.compile(pattern)

    if isinstance(items, dict):
        # Dict results keep the original values and are never squashed.
        return {key: items[key] for key in items if regex.match(key)}

    found = []
    for entry in items:
        hit = regex.match(entry)
        if hit is None:
            continue
        groups = hit.groups()
        if len(groups) == 0:
            found.append(hit.group(0))
        elif len(groups) == 1:
            found.append(groups[0])
        else:
            found.append(list(groups))

    if squash and len(found) == 1:
        return found[0]
    return found
def basename(path):
    """Return the final '/'-separated component of *path*.

    An os.path.basename() counterpart that always splits on '/', because
    git's output always uses '/' regardless of platform.
    """
    _, _, tail = path.rpartition('/')
    return tail
def strip_one(path):
    """Strip one level of directory from *path*.

    Leading and trailing slashes are removed first.  A path with no
    remaining '/' is returned as-is after that stripping.
    """
    trimmed = path.strip('/')
    head, sep, tail = trimmed.partition('/')
    return tail if sep else head
def dirname(path, current_dir=''):
    """Return everything before the final '/' in *path*.

    An os.path.dirname() counterpart that always splits on '/', because
    git's output always uses '/' regardless of platform.  Runs of repeated
    slashes are collapsed first.  When *path* contains no '/' at all,
    *current_dir* is returned.
    """
    while '//' in path:
        path = path.replace('//', '/')
    head, sep, _ = path.rpartition('/')
    if not sep:
        return current_dir
    return head
def splitpath(path):
    """Return the components of *path* split on '/', independent of platform."""
    separator = '/'
    return path.split(separator)
def split(name):
    """Split a path-like name into ``(head, tail)`` using only '/'.

    ``tail`` is everything after the final slash; ``head`` may be empty.
    This is the same as os.path.split() but only uses '/' as the delimiter.

    >>> split('a/b/c')
    ('a/b', 'c')

    >>> split('xyz')
    ('', 'xyz')

    """
    head = dirname(name)
    tail = basename(name)
    return (head, tail)
def join(*paths):
    """Join the given path components with '/', independent of platform.

    >>> join('a', 'b', 'c')
    'a/b/c'

    """
    separator = '/'
    return separator.join(paths)
def normalize_slash(value):
    """Strip surrounding slashes and collapse every run of '/' into one.

    >>> normalize_slash('///Meow///Cat///')
    'Meow/Cat'

    """
    value = value.strip('/')
    while '//' in value:
        value = value.replace('//', '/')
    return value
def pathjoin(paths):
    """Join a sequence of path components with '/', independent of platform.

    >>> pathjoin(['a', 'b', 'c'])
    'a/b/c'

    """
    return '/'.join(paths)
def pathset(path):
    """Return the cumulative '/'-joined prefixes of *path*, one per component.

    >>> pathset('foo/bar/baz') == ['foo', 'foo/bar', 'foo/bar/baz']
    True

    """
    prefixes = []
    current = None
    for component in path.split('/'):
        if current is None:
            current = component
        else:
            current = current + '/' + component
        prefixes.append(current)
    return prefixes
def select_directory(paths):
    """Return the first directory in a list of paths.

    With no paths, the current working directory is returned.  When none of
    the paths is a directory, the parent of the first path is used, falling
    back to the current working directory when that parent is empty.
    """
    if not paths:
        return core.getcwd()

    for candidate in paths:
        if core.isdir(candidate):
            return candidate

    parent = os.path.dirname(paths[0])
    return parent if parent else core.getcwd()
def strip_prefix(prefix, string):
    """Return *string* with the leading *prefix* removed.

    :raises AssertionError: if *string* does not start with *prefix*.  The
        check is an explicit ``raise`` rather than an ``assert`` statement,
        so it is still enforced when Python runs with ``-O``.
    """
    if not string.startswith(prefix):
        # Keep the exception type the former ``assert`` produced so any
        # existing handlers in callers continue to work.
        raise AssertionError('%r does not start with %r' % (string, prefix))
    return string[len(prefix) :]
def tablength(word, tabwidth):
    """Return the length of *word*, counting each tab as *tabwidth* columns.

    >>> tablength("\\t\\t\\t\\tX", 8)
    33

    """
    tab_count = word.count('\t')
    return len(word) - tab_count + tab_count * tabwidth
def _shell_split_py2(value):
    """Split *value* into arguments on Python 2, returning unicode strings.

    Python 2's shlex.split() requires bytes, so *value* is encoded first and
    each resulting argument is decoded back.  If shlex rejects the input
    (ValueError), the encoded value is split on whitespace instead.
    """
    encoded = core.encode(value)
    try:
        pieces = shlex.split(encoded)
    except ValueError:
        pieces = encoded.strip().split()
    return [core.decode(piece) for piece in pieces]
258 def _shell_split_py3(value):
259 """Python3 requires Unicode inputs to shlex.split(). Convert to Unicode"""
260 try:
261 result = shlex.split(value)
262 except ValueError:
263 result = core.decode(value).strip().split()
264 # Already Unicode
265 return result
def shell_split(value):
    """Split a command-line string into arguments on either Python version."""
    # Python 2 needs the encode/decode dance; Python 3 works on unicode.
    splitter = _shell_split_py2 if compat.PY2 else _shell_split_py3
    return splitter(value)
def tmp_filename(label, suffix=''):
    """Create a temporary file that is not auto-deleted and return its path.

    '/' and '\\' in *label* are replaced with '-'.  The file name starts with
    ``git-cola-<label>-`` and ends with *suffix*.
    """
    safe_label = label.replace('/', '-').replace('\\', '-')
    prefix = 'git-cola-' + safe_label + '-'
    handle = tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix, delete=False)
    with handle:
        return handle.name
def is_linux():
    """Return True when sys.platform reports Linux."""
    plat = sys.platform
    return plat.startswith('linux')
def is_debian():
    """Return True on Debian-style systems, detected by /usr/bin/apt-get."""
    apt_get = '/usr/bin/apt-get'
    return os.path.exists(apt_get)
def is_darwin():
    """Return True when sys.platform reports macOS ('darwin')."""
    return 'darwin' == sys.platform
def is_win32():
    """Return True on native Windows ('win32') or Cygwin ('cygwin')."""
    plat = sys.platform
    return plat == 'win32' or plat == 'cygwin'
def launch_default_app(paths):
    """Execute the default application on the specified paths

    When is_win32() is true, each path's absolute form is passed to
    os.startfile() if that function exists; otherwise nothing is launched.
    On macOS ``open`` is used and elsewhere ``xdg-open``, invoked through
    core.fork() with all paths as arguments.

    :param paths: an iterable of paths (a list, tuple, set or generator)
    """
    if is_win32():
        # os.startfile is only available on Windows builds of Python; the
        # lookup is loop-invariant, so do it once.
        startfile = getattr(os, 'startfile', None)
        if startfile is not None:
            for path in paths:
                startfile(os.path.abspath(path))
        return

    launcher = 'open' if is_darwin() else 'xdg-open'
    # list() accepts any iterable; ``[launcher] + paths`` only worked for lists.
    core.fork([launcher] + list(paths))
def expandpath(path):
    """Expand environment $variables, then a leading ~ or ~user, in *path*."""
    expanded = os.path.expandvars(path)
    if not expanded.startswith('~'):
        return expanded
    return os.path.expanduser(expanded)
class Group:
    """Operate on a collection of objects as a single unit

    Looking up an attribute that the group does not have returns a function
    that calls the same-named method on every member, in order.
    """

    def __init__(self, *members):
        self._members = members

    def __getattr__(self, name):
        """Return a function that relays calls to the group

        The relay is cached on the instance, so later lookups of the same name
        bypass __getattr__.  Dunder names and ``_members`` are never relayed.
        Protocols such as ``copy`` and ``pickle`` probe for optional hooks
        (``__deepcopy__``, ``__setstate__``, ...) and must get AttributeError,
        not a relay.
        """
        if name == '_members' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(
                '%r object has no attribute %r' % (type(self).__name__, name)
            )

        def relay(*args, **kwargs):
            for member in self._members:
                method = getattr(member, name)
                method(*args, **kwargs)

        setattr(self, name, relay)
        return relay
class Proxy:
    """Wrap an object and override attributes

    Keyword arguments given to the constructor become instance attributes
    that shadow the wrapped object's.  Every other attribute lookup is
    forwarded to the wrapped object.
    """

    def __init__(self, obj, **overrides):
        self._obj = obj
        for k, v in overrides.items():
            setattr(self, k, v)

    def __getattr__(self, name):
        # ``_obj`` is only missing before __init__ has run, e.g. while ``copy``
        # or ``pickle`` rebuilds an instance.  Forwarding that lookup would call
        # __getattr__('_obj') again and recurse until RecursionError.
        if name == '_obj':
            raise AttributeError(name)
        return getattr(self._obj, name)
def slice_func(input_items, map_func):
    """Slice input_items and call `map_func` over every slice

    This exists because of "errno: Argument list too long"

    ``map_func`` is called with successive slices of ``input_items`` and must
    return a ``(status, out, err)`` triple.  Statuses are folded in order:
    a negative status is combined with min() and any other with max().  The
    outputs and errors of all calls are joined with newlines.

    :returns: ``(status, out, err)``
    """
    # This comment appeared near the top of include/linux/binfmts.h
    # in the Linux source tree:
    #
    # /*
    #  * MAX_ARG_PAGES defines the number of pages allocated for arguments
    #  * and envelope for the new program. 32 should suffice, this gives
    #  * a maximum env+arg of 128kB w/4KB pages!
    #  */
    # #define MAX_ARG_PAGES 32
    #
    # 'size' is a heuristic to keep things highly performant by minimizing
    # the number of slices. If we wanted it to run as few commands as
    # possible we could call "getconf ARG_MAX" and make a better guess,
    # but it's probably not worth the complexity (and the extra call to
    # getconf that we can't do on Windows anyways).
    #
    # In my testing, getconf ARG_MAX on Mac OS X Mountain Lion reported
    # 262144 and Debian/Linux-x86_64 reported 2097152.
    #
    # The hard-coded max_arg_len value is safely below both of these
    # real-world values.

    # 4K pages x 32 MAX_ARG_PAGES
    max_arg_len = (32 * 4096) // 4  # allow plenty of space for the environment
    max_filename_len = 256
    size = max_arg_len // max_filename_len

    items = copy.copy(input_items)
    status = 0
    outs = []
    errs = []
    for start in range(0, len(items), size):
        stat, out, err = map_func(items[start : start + size])
        status = min(stat, status) if stat < 0 else max(stat, status)
        outs.append(out)
        errs.append(err)

    return (status, '\n'.join(outs), '\n'.join(errs))
class Sequence:
    """Wrap a sequence so that index() returns a default instead of raising."""

    def __init__(self, sequence):
        self.sequence = sequence

    def index(self, item, default=-1):
        """Return the position of *item*, or *default* if it is not present."""
        try:
            return self.sequence.index(item)
        except ValueError:
            return default

    def __getitem__(self, idx):
        return self.sequence[idx]
def catch_runtime_error(func, *args, **kwargs):
    """Run the function safely.

    Catch RuntimeError to avoid tracebacks during application shutdown.

    :returns: ``(True, result)`` when ``func(*args, **kwargs)`` returns, or
        ``(False, None)`` when it raises RuntimeError.  Other exceptions
        propagate unchanged.
    """
    # Signals and callbacks can sometimes get triggered during application shutdown.
    # This can happen when exiting while background tasks are still processing.
    # Guard against this by making this operation a no-op.
    try:
        result = func(*args, **kwargs)
    except RuntimeError:
        return (False, None)
    return (True, result)