Coverage for /Volumes/workspace/python-progressbar/.tox/py38/lib/python3.8/site-packages/progressbar/utils.py: 95%
225 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-01 16:14 +0100
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-01 16:14 +0100
1from __future__ import annotations
3import atexit
4import datetime
5import io
6import logging
7import os
8import re
9import sys
10from types import TracebackType
11from typing import Iterable, Iterator, Type
13from python_utils import types
14from python_utils.converters import scale_1024
15from python_utils.terminal import get_terminal_size
16from python_utils.time import epoch, format_time, timedelta_to_seconds
18from progressbar import base
if types.TYPE_CHECKING:
    # Imported for static type checking only; in this module the two names
    # appear solely inside annotations.
    from .bar import ProgressBar, ProgressBarMixinBase

# Bare references to the helpers imported from `python_utils` above, so the
# imports count as used.
# NOTE(review): presumably these names are re-exported from this module for
# backwards compatibility -- confirm against callers before removing.
assert timedelta_to_seconds
assert get_terminal_size
assert format_time
assert scale_1024
assert epoch

# Type variable bound to `types.StringTypes`; lets `no_color` declare that
# it returns the same string type it was given.
StringT = types.TypeVar('StringT', bound=types.StringTypes)

# Regular expression fragments describing `TERM` values that are treated as
# ANSI capable.
ANSI_TERMS = (
    '([xe]|bv)term',
    '(sco)?ansi',
    'cygwin',
    'konsole',
    'linux',
    'rxvt',
    'screen',
    'tmux',
    'vt(10[02]|220|320)',
)
# Case-insensitive alternation of all fragments, anchored at the start of the
# string; `is_ansi_terminal` matches it against the `TERM` environment
# variable.
ANSI_TERM_RE = re.compile('^({})'.format('|'.join(ANSI_TERMS)), re.IGNORECASE)
def is_ansi_terminal(
    fd: base.IO, is_terminal: bool | None = None
) -> bool:  # pragma: no cover
    '''
    Tell whether output written to `fd` may contain ANSI escape codes.

    An explicit (non-`None`) `is_terminal` is simply returned as a `bool`.
    Otherwise the environment and `fd` are probed in this order:

    1. `JPY_PARENT_PID` is set (Jupyter) -> `True`
    2. `PYCHARM_HOSTED` equals `'1'` and `PYTEST_CURRENT_TEST` is unset or
       empty -> `True`
    3. `fd.isatty()` is truthy and `TERM` matches `ANSI_TERM_RE` -> `True`
    4. `ANSICON` is set -> `True`

    Anything else, including any exception raised while probing `fd`,
    results in `False`.
    '''
    if is_terminal is not None:
        return bool(is_terminal)

    environ = os.environ

    # Jupyter Notebooks define this variable and support progress bars
    if 'JPY_PARENT_PID' in environ:
        return True

    # Only newer PyCharm versions export this variable; older versions give
    # us nothing to detect them by.
    pycharm_hosted = environ.get('PYCHARM_HOSTED') == '1'
    if pycharm_hosted and not environ.get('PYTEST_CURRENT_TEST'):
        return True

    # Check whether we are writing to a terminal. A replaced stream that does
    # not define `isatty` leaves us without any way of knowing, in which case
    # ANSI is not used.
    try:
        # Try and match any of the huge amount of Linux/Unix ANSI consoles
        if fd.isatty() and ANSI_TERM_RE.match(environ.get('TERM', '')):
            return True

        # ANSICON is a Windows ANSI compatible console
        return 'ANSICON' in environ
    except Exception:
        return False
def is_terminal(fd: base.IO, is_terminal: bool | None = None) -> bool:
    '''
    Tell whether `fd` should be treated as a terminal.

    An explicit (non-`None`) `is_terminal` wins. Otherwise, in order: full
    ANSI support as reported by `is_ansi_terminal`, the
    `PROGRESSBAR_IS_TERMINAL` environment flag, and finally `fd.isatty()`.
    '''
    if is_terminal is not None:
        return bool(is_terminal)

    # Full ansi support encompasses what we expect from a terminal
    if is_ansi_terminal(fd):
        return True

    # Allow an environment variable override
    override = env_flag('PROGRESSBAR_IS_TERMINAL', None)
    if override is not None:
        return bool(override)

    # Broad except because a lot can go wrong on different systems. If we do
    # get a TTY we know this is a valid terminal
    try:  # pragma: no cover
        tty = fd.isatty()
    except Exception:
        return False
    return bool(tty)
def deltas_to_seconds(
    *deltas,
    default: types.Optional[types.Type[ValueError]] = ValueError,
) -> int | float | None:
    '''
    Return the first non-`None` delta as a number of seconds.

    `datetime.timedelta` values go through `timedelta_to_seconds`, floats
    are returned untouched and everything else is passed to `float()`.
    When no usable delta is given, `default` is returned, unless it is the
    `ValueError` class (the default), in which case a `ValueError` is raised.

    >>> deltas_to_seconds(datetime.timedelta(seconds=1, milliseconds=234))
    1.234
    >>> deltas_to_seconds(123)
    123.0
    >>> deltas_to_seconds(1.234)
    1.234
    >>> deltas_to_seconds(None, 1.234)
    1.234
    >>> deltas_to_seconds(0, 1.234)
    0.0
    >>> deltas_to_seconds()
    Traceback (most recent call last):
    ...
    ValueError: No valid deltas passed to `deltas_to_seconds`
    >>> deltas_to_seconds(None)
    Traceback (most recent call last):
    ...
    ValueError: No valid deltas passed to `deltas_to_seconds`
    >>> deltas_to_seconds(default=0.0)
    0.0
    '''
    # Only `None` is skipped; falsy values such as `0` are valid deltas.
    candidate = next((item for item in deltas if item is not None), None)

    if candidate is None:
        if default is ValueError:
            raise ValueError('No valid deltas passed to `deltas_to_seconds`')
        # mypy doesn't understand the `default is ValueError` check
        return default  # type: ignore

    if isinstance(candidate, datetime.timedelta):
        return timedelta_to_seconds(candidate)
    if isinstance(candidate, float):
        return candidate
    return float(candidate)
def no_color(value: StringT) -> StringT:
    '''
    Strip ANSI escape sequences from `value`, keeping its type (`bytes` in,
    `bytes` out; `str` in, `str` out).

    >>> no_color(b'\u001b[1234]abc') == b'abc'
    True
    >>> str(no_color(u'\u001b[1234]abc'))
    'abc'
    >>> str(no_color('\u001b[1234]abc'))
    'abc'
    '''
    if not isinstance(value, bytes):
        return re.sub(u'\x1b\\[.*?[@-~]', '', value)  # type: ignore

    # Bytes input needs a bytes pattern; it is built by encoding the text
    # pattern.
    pattern: bytes = '\\\u001b\\[.*?[@-~]'.encode()
    return re.sub(pattern, b'', value)  # type: ignore
def len_color(value: types.StringTypes) -> int:
    '''
    Length of `value` once `no_color` has removed the ANSI escape codes.

    >>> len_color(b'\u001b[1234]abc')
    3
    >>> len_color(u'\u001b[1234]abc')
    3
    >>> len_color('\u001b[1234]abc')
    3
    '''
    stripped = no_color(value)
    return len(stripped)
def env_flag(name: str, default: bool | None = None) -> bool | None:
    '''
    Read the environment variable `name` as a boolean flag.

    Recognised values (case-insensitive) are y/n, yes/no, t/f, true/false,
    on/off and 1/0.

    If the environment variable is not defined, is empty, or has an unknown
    value, `default` is returned.
    '''
    raw = os.getenv(name)
    if not raw:
        return default

    lowered = raw.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    return default
class WrappingIO:
    '''
    File-like proxy around `target` that can hold writes back in a buffer.

    While `capturing` is true, written text is collected in an in-memory
    buffer instead of reaching `target`; `_flush` later forwards the
    collected text. While `capturing` is false, writes go straight to
    `target`. All other file methods delegate to `target`.
    '''

    buffer: io.StringIO
    target: base.IO
    capturing: bool
    listeners: set
    needs_clear: bool = False

    def __init__(
        self,
        target: base.IO,
        capturing: bool = False,
        listeners: types.Optional[types.Set[ProgressBar]] = None,
    ) -> None:
        '''
        :param target: the stream being wrapped
        :param capturing: whether writes are buffered from the start
        :param listeners: objects whose `update()` is called when a captured
            write contains a newline
        '''
        self.buffer = io.StringIO()
        self.target = target
        self.capturing = capturing
        # NOTE(review): `or` also replaces an *empty* set passed by the
        # caller with a fresh one, so later additions to the caller's set are
        # not seen here -- confirm whether sharing was intended.
        self.listeners = listeners or set()
        self.needs_clear = False

    def write(self, value: str) -> int:
        '''
        Write `value` to the buffer (capturing) or to the target (otherwise)
        and return the count reported by the underlying `write`.
        '''
        ret = 0
        if self.capturing:
            ret += self.buffer.write(value)
            if '\n' in value:  # pragma: no branch
                # A complete line was captured: flag it and notify listeners.
                self.needs_clear = True
                for listener in self.listeners:  # pragma: no branch
                    listener.update()
        else:
            ret += self.target.write(value)
            if '\n' in value:  # pragma: no branch
                self.flush_target()

        return ret

    def flush(self) -> None:
        '''Flush the capture buffer only; the target is left untouched.'''
        self.buffer.flush()

    def _flush(self) -> None:
        '''Forward captured text to the target and reset the buffer.'''
        value = self.buffer.getvalue()
        if value:
            self.flush()
            self.target.write(value)
            self.buffer.seek(0)
            self.buffer.truncate(0)
            self.needs_clear = False

        # when explicitly flushing, always flush the target as well
        self.flush_target()

    def flush_target(self) -> None:  # pragma: no cover
        '''
        Flush the target when it is open and offers a callable `flush`.

        Targets without a `closed` or `flush` attribute are tolerated: a
        missing `closed` counts as open, a missing `flush` is skipped.
        '''
        if getattr(self.target, 'closed', False):
            return

        flush = getattr(self.target, 'flush', None)
        if callable(flush):
            flush()

    def __enter__(self) -> WrappingIO:
        return self

    # The methods below delegate directly to the wrapped target.

    def fileno(self) -> int:
        return self.target.fileno()

    def isatty(self) -> bool:
        return self.target.isatty()

    def read(self, n: int = -1) -> str:
        return self.target.read(n)

    def readable(self) -> bool:
        return self.target.readable()

    def readline(self, limit: int = -1) -> str:
        return self.target.readline(limit)

    def readlines(self, hint: int = -1) -> list[str]:
        return self.target.readlines(hint)

    def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
        return self.target.seek(offset, whence)

    def seekable(self) -> bool:
        return self.target.seekable()

    def tell(self) -> int:
        return self.target.tell()

    def truncate(self, size: types.Optional[int] = None) -> int:
        return self.target.truncate(size)

    def writable(self) -> bool:
        return self.target.writable()

    def writelines(self, lines: Iterable[str]) -> None:
        # Goes straight to the target; the capture buffer is not involved.
        return self.target.writelines(lines)

    def close(self) -> None:
        '''Flush the capture buffer and close the target.'''
        # NOTE(review): `flush()` does not forward captured text to the
        # target, so text still in the buffer is not written before the
        # target is closed -- confirm whether `_flush()` was intended.
        self.flush()
        self.target.close()

    def __next__(self) -> str:
        return self.target.__next__()

    def __iter__(self) -> Iterator[str]:
        return self.target.__iter__()

    def __exit__(
        self,
        __t: Type[BaseException] | None,
        __value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> None:
        self.close()
class StreamWrapper:
    '''
    Swap `sys.stdout` / `sys.stderr` for `WrappingIO` proxies on request.

    Wrap requests are counted per stream, capture requests are counted
    globally, and `sys.excepthook` is replaced while a stream is wrapped so
    that buffered output is flushed after an uncaught exception is reported.
    '''

    stdout: base.TextIO | WrappingIO
    stderr: base.TextIO | WrappingIO
    original_excepthook: types.Callable[
        [
            types.Type[BaseException],
            BaseException,
            TracebackType | None,
        ],
        None,
    ]
    wrapped_stdout: int = 0
    wrapped_stderr: int = 0
    wrapped_excepthook: int = 0
    capturing: int = 0
    listeners: set

    def __init__(self):
        # Remember the streams and hook that are active right now so they can
        # be restored later.
        self.original_stdout = sys.stdout
        self.stdout = self.original_stdout
        self.original_stderr = sys.stderr
        self.stderr = self.original_stderr
        self.original_excepthook = sys.excepthook

        self.wrapped_stdout = 0
        self.wrapped_stderr = 0
        self.wrapped_excepthook = 0
        self.capturing = 0
        self.listeners = set()

        # Wrapping can be requested through the environment.
        if env_flag('WRAP_STDOUT', default=False):  # pragma: no cover
            self.wrap_stdout()

        if env_flag('WRAP_STDERR', default=False):  # pragma: no cover
            self.wrap_stderr()

    def start_capturing(self, bar: ProgressBarMixinBase | None = None) -> None:
        '''Register `bar` (if given) and add one capture request.'''
        if bar:  # pragma: no branch
            self.listeners.add(bar)

        self.capturing += 1
        self.update_capturing()

    def stop_capturing(self, bar: ProgressBarMixinBase | None = None) -> None:
        '''Unregister `bar` (if given) and drop one capture request.'''
        if bar:  # pragma: no branch
            # A bar that was never registered is silently ignored.
            self.listeners.discard(bar)

        self.capturing -= 1
        self.update_capturing()

    def update_capturing(self) -> None:  # pragma: no cover
        '''Propagate the capture state to the wrapped streams.'''
        for stream in (self.stdout, self.stderr):
            if isinstance(stream, WrappingIO):
                stream.capturing = self.capturing > 0

        # Nobody is capturing anymore: release whatever was buffered.
        if self.capturing <= 0:
            self.flush()

    def wrap(self, stdout: bool = False, stderr: bool = False) -> None:
        '''Wrap the requested streams.'''
        if stdout:
            self.wrap_stdout()

        if stderr:
            self.wrap_stderr()

    def wrap_stdout(self) -> base.IO:
        '''Install the stdout proxy on first use and count the request.'''
        self.wrap_excepthook()

        if not self.wrapped_stdout:
            proxy = WrappingIO(self.original_stdout, listeners=self.listeners)
            self.stdout = sys.stdout = proxy  # type: ignore
        self.wrapped_stdout += 1

        return sys.stdout

    def wrap_stderr(self) -> base.IO:
        '''Install the stderr proxy on first use and count the request.'''
        self.wrap_excepthook()

        if not self.wrapped_stderr:
            proxy = WrappingIO(self.original_stderr, listeners=self.listeners)
            self.stderr = sys.stderr = proxy  # type: ignore
        self.wrapped_stderr += 1

        return sys.stderr

    def unwrap_excepthook(self) -> None:
        '''Restore the original `sys.excepthook` if ours is installed.'''
        if not self.wrapped_excepthook:
            return

        self.wrapped_excepthook -= 1
        sys.excepthook = self.original_excepthook

    def wrap_excepthook(self) -> None:
        '''Install `self.excepthook` unless that already happened.'''
        if self.wrapped_excepthook:
            return

        logger.debug('wrapping excepthook')
        self.wrapped_excepthook += 1
        sys.excepthook = self.excepthook

    def unwrap(self, stdout: bool = False, stderr: bool = False) -> None:
        '''Unwrap the requested streams.'''
        if stdout:
            self.unwrap_stdout()

        if stderr:
            self.unwrap_stderr()

    def unwrap_stdout(self) -> None:
        '''Drop one stdout wrap request; the last one restores stdout.'''
        if self.wrapped_stdout > 1:
            self.wrapped_stdout -= 1
            return

        sys.stdout = self.original_stdout
        self.wrapped_stdout = 0

    def unwrap_stderr(self) -> None:
        '''Drop one stderr wrap request; the last one restores stderr.'''
        if self.wrapped_stderr > 1:
            self.wrapped_stderr -= 1
            return

        sys.stderr = self.original_stderr
        self.wrapped_stderr = 0

    def needs_clear(self) -> bool:  # pragma: no cover
        '''Report the `needs_clear` flag of either stream (stderr first).'''
        from_stdout, from_stderr = (
            getattr(self.stdout, 'needs_clear', False),
            getattr(self.stderr, 'needs_clear', False),
        )
        return from_stderr or from_stdout

    def flush(self) -> None:
        '''Forward buffered output of every wrapped stream to its target.'''
        stdout_wrapped = self.wrapped_stdout and isinstance(
            self.stdout, WrappingIO
        )
        if stdout_wrapped:  # pragma: no branch
            try:
                self.stdout._flush()
            except io.UnsupportedOperation:  # pragma: no cover
                self.wrapped_stdout = False
                logger.warning(
                    'Disabling stdout redirection, %r is not seekable',
                    sys.stdout,
                )

        stderr_wrapped = self.wrapped_stderr and isinstance(
            self.stderr, WrappingIO
        )
        if stderr_wrapped:  # pragma: no branch
            try:
                self.stderr._flush()
            except io.UnsupportedOperation:  # pragma: no cover
                self.wrapped_stderr = False
                logger.warning(
                    'Disabling stderr redirection, %r is not seekable',
                    sys.stderr,
                )

    def excepthook(self, exc_type, exc_value, exc_traceback):
        '''Report the exception through the original hook, then flush.'''
        self.original_excepthook(exc_type, exc_value, exc_traceback)
        self.flush()
class AttributeDict(dict):
    '''
    A `dict` whose items are also reachable as attributes.

    >>> attrs = AttributeDict(spam=123)

    # Reading

    >>> attrs['spam']
    123
    >>> attrs.spam
    123

    # Read after update using attribute

    >>> attrs.spam = 456
    >>> attrs['spam']
    456
    >>> attrs.spam
    456

    # Read after update using dict access

    >>> attrs['spam'] = 123
    >>> attrs['spam']
    123
    >>> attrs.spam
    123

    # Read after deleting using attribute

    >>> del attrs.spam
    >>> attrs['spam']
    Traceback (most recent call last):
    ...
    KeyError: 'spam'
    >>> attrs.spam
    Traceback (most recent call last):
    ...
    AttributeError: No such attribute: spam
    >>> del attrs.spam
    Traceback (most recent call last):
    ...
    AttributeError: No such attribute: spam
    '''

    def __getattr__(self, name: str) -> int:
        # Only reached when regular attribute lookup fails, so this maps the
        # name onto the dict content.
        if name not in self:
            raise AttributeError("No such attribute: " + name)
        return self[name]

    def __setattr__(self, name: str, value: int) -> None:
        # Every attribute assignment is stored as a dict item.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        if name not in self:
            raise AttributeError("No such attribute: " + name)
        del self[name]
# Module-level logger named after this module; the `StreamWrapper` methods
# log through it.
logger = logging.getLogger(__name__)
# Shared `StreamWrapper` instance created at import time. Its constructor
# already wraps stdout/stderr when the `WRAP_STDOUT` / `WRAP_STDERR`
# environment flags are set.
streams = StreamWrapper()
# Make sure output that is still buffered gets flushed at interpreter exit.
atexit.register(streams.flush)