Coverage for /Volumes/workspace/python-progressbar/.tox/py39/lib/python3.9/site-packages/progressbar/utils.py: 95%

225 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-01 16:14 +0100

1from __future__ import annotations 

2 

3import atexit 

4import datetime 

5import io 

6import logging 

7import os 

8import re 

9import sys 

10from types import TracebackType 

11from typing import Iterable, Iterator, Type 

12 

13from python_utils import types 

14from python_utils.converters import scale_1024 

15from python_utils.terminal import get_terminal_size 

16from python_utils.time import epoch, format_time, timedelta_to_seconds 

17 

18from progressbar import base 

19 

20if types.TYPE_CHECKING: 

21 from .bar import ProgressBar, ProgressBarMixinBase 

22 

# The names below are imported but not used by this module itself;
# referencing them here presumably keeps linters from reporting them as
# unused imports (they appear to be re-exported for other modules).
assert all(
    (timedelta_to_seconds, get_terminal_size, format_time, scale_1024, epoch)
)

# Type variable used by `no_color` so that it returns the same string type it
# was given; bound to python_utils' `StringTypes`.
StringT = types.TypeVar('StringT', bound=types.StringTypes)

30 

# Regex prefixes of `TERM` values for consoles that understand ANSI escape
# sequences; they are joined into one anchored, case-insensitive pattern.
ANSI_TERMS = (
    '([xe]|bv)term',
    '(sco)?ansi',
    'cygwin',
    'konsole',
    'linux',
    'rxvt',
    'screen',
    'tmux',
    'vt(10[02]|220|320)',
)
ANSI_TERM_RE = re.compile(
    '^(' + '|'.join(ANSI_TERMS) + ')',
    flags=re.IGNORECASE,
)

43 

44 

def is_ansi_terminal(
    fd: base.IO, is_terminal: bool | None = None
) -> bool:  # pragma: no cover
    '''
    Return whether `fd` should be treated as an ANSI capable terminal.

    An explicit `is_terminal` argument short-circuits all detection. Otherwise
    the Jupyter and (recent) PyCharm environment markers are honoured, and
    finally `fd.isatty()` is combined with the `TERM` and `ANSICON`
    environment variables. Any exception raised while probing `fd` results
    in False.
    '''
    if is_terminal is not None:
        return bool(is_terminal)

    env = os.environ

    # Jupyter Notebooks define this variable and support progress bars
    if 'JPY_PARENT_PID' in env:
        return True

    # Only newer PyCharm versions set PYCHARM_HOSTED (older ones cannot be
    # detected); the marker is ignored while running under pytest.
    if env.get('PYCHARM_HOSTED') == '1' and not env.get(
        'PYTEST_CURRENT_TEST'
    ):
        return True

    # A file object whose isatty() was overridden or is missing gives us no
    # way of knowing, so any failure here means "no ANSI".
    try:
        tty = fd.isatty()
        # Match the many Linux/Unix ANSI consoles by their TERM value
        if tty and ANSI_TERM_RE.match(env.get('TERM', '')):
            return True
        # ANSICON is a Windows ANSI compatible console
        return 'ANSICON' in env
    except Exception:
        return False

79 

80 

def is_terminal(fd: base.IO, is_terminal: bool | None = None) -> bool:
    '''
    Return whether `fd` should be treated as an interactive terminal.

    The first definite answer wins, checked in this order:
    an explicit `is_terminal` argument, full ANSI support according to
    `is_ansi_terminal`, the `PROGRESSBAR_IS_TERMINAL` environment flag and
    finally `fd.isatty()`. An exception from `fd.isatty()` counts as
    "not a terminal".
    '''
    if is_terminal is not None:
        return bool(is_terminal)

    # Full ansi support encompasses what we expect from a terminal; a
    # negative answer is inconclusive, so fall through to the next check.
    if is_ansi_terminal(fd):
        return True

    # Allow an environment variable override
    override = env_flag('PROGRESSBAR_IS_TERMINAL', None)
    if override is not None:
        return bool(override)

    # A lot can go wrong on different systems; only a real TTY counts.
    try:  # pragma: no cover
        return bool(fd.isatty())
    except Exception:  # pragma: no cover
        return False

99 

100 

def deltas_to_seconds(
    *deltas,
    default: types.Union[float, None, types.Type[ValueError]] = ValueError,
) -> int | float | None:
    '''
    Convert timedeltas and seconds as int to seconds as float while coalescing

    The first argument in `deltas` that is not None is converted and
    returned. A `datetime.timedelta` goes through `timedelta_to_seconds`, a
    float is returned unchanged and anything else is passed to `float()`.
    If every delta is None, or none are given, `default` is returned. The
    exception is when `default` was left at its sentinel value
    `ValueError`, in which case a ValueError is raised.

    >>> deltas_to_seconds(datetime.timedelta(seconds=1, milliseconds=234))
    1.234
    >>> deltas_to_seconds(123)
    123.0
    >>> deltas_to_seconds(1.234)
    1.234
    >>> deltas_to_seconds(None, 1.234)
    1.234
    >>> deltas_to_seconds(0, 1.234)
    0.0
    >>> deltas_to_seconds()
    Traceback (most recent call last):
    ...
    ValueError: No valid deltas passed to `deltas_to_seconds`
    >>> deltas_to_seconds(None)
    Traceback (most recent call last):
    ...
    ValueError: No valid deltas passed to `deltas_to_seconds`
    >>> deltas_to_seconds(default=0.0)
    0.0
    >>> deltas_to_seconds(None, default=None) is None
    True
    '''
    for delta in deltas:
        if delta is None:
            continue
        if isinstance(delta, datetime.timedelta):
            return timedelta_to_seconds(delta)
        elif not isinstance(delta, float):
            return float(delta)
        else:
            return delta

    # `ValueError` (the class itself) is the "no default given" sentinel
    if default is ValueError:
        raise ValueError('No valid deltas passed to `deltas_to_seconds`')
    else:
        # mypy doesn't understand the `default is ValueError` check
        return default  # type: ignore

144 

145 

def no_color(value: StringT) -> StringT:
    '''
    Strip ANSI escape sequences from `value`, keeping its type.

    Both `str` and `bytes` are accepted; the result has the same type as the
    input.

    >>> no_color(b'\u001b[1234]abc') == b'abc'
    True
    >>> str(no_color(u'\u001b[1234]abc'))
    'abc'
    >>> str(no_color('\u001b[1234]abc'))
    'abc'
    '''
    # Each pattern matches ESC, '[', then the shortest run of characters up
    # to the first one in the '@'..'~' range.
    if not isinstance(value, bytes):
        return re.sub('\x1b\\[.*?[@-~]', '', value)  # type: ignore

    pattern: bytes = b'\\\x1b\\[.*?[@-~]'
    return re.sub(pattern, b'', value)  # type: ignore

162 

163 

def len_color(value: types.StringTypes) -> int:
    '''
    Return the length of `value` once ANSI escape codes are removed.

    >>> len_color(b'\u001b[1234]abc')
    3
    >>> len_color(u'\u001b[1234]abc')
    3
    >>> len_color('\u001b[1234]abc')
    3
    '''
    visible = no_color(value)
    return len(visible)

176 

177 

def env_flag(name: str, default: bool | None = None) -> bool | None:
    '''
    Interpret the environment variable `name` as a boolean flag.

    Recognised values (case-insensitive) are y/n, yes/no, t/f, true/false,
    on/off and 1/0. If the variable is undefined, empty or holds any other
    value, `default` is returned.
    '''
    raw = os.getenv(name)
    if not raw:
        return default

    lowered = raw.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    return default

192 

193 

class WrappingIO:
    '''
    File-like proxy around `target` that can temporarily capture writes.

    While `capturing` is true, text written through `write` / `writelines`
    is kept in an in-memory buffer instead of reaching `target`. A captured
    write containing a newline sets `needs_clear` and calls `update()` on
    every registered listener. `_flush` writes the buffered text to
    `target`. All other file operations are delegated to `target`.
    '''

    buffer: io.StringIO
    target: base.IO
    capturing: bool
    listeners: set
    needs_clear: bool = False

    def __init__(
        self,
        target: base.IO,
        capturing: bool = False,
        listeners: types.Optional[types.Set[ProgressBar]] = None,
    ) -> None:
        self.buffer = io.StringIO()
        self.target = target
        self.capturing = capturing
        # Keep the caller's set itself, even while it is still empty, so
        # listeners added to it later are notified too. `listeners or set()`
        # would silently swap an empty shared set for a private copy.
        self.listeners = set() if listeners is None else listeners
        self.needs_clear = False

    def write(self, value: str) -> int:
        '''Write `value` (to the buffer while capturing); return the count.'''
        ret = 0
        if self.capturing:
            ret += self.buffer.write(value)
            if '\n' in value:  # pragma: no branch
                self.needs_clear = True
                for listener in self.listeners:  # pragma: no branch
                    listener.update()
        else:
            ret += self.target.write(value)
            if '\n' in value:  # pragma: no branch
                self.flush_target()

        return ret

    def flush(self) -> None:
        '''Flush the internal buffer only; nothing is written to `target`.'''
        self.buffer.flush()

    def _flush(self) -> None:
        '''Move captured text to `target`, reset the buffer, flush target.'''
        value = self.buffer.getvalue()
        if value:
            self.flush()
            self.target.write(value)
            self.buffer.seek(0)
            self.buffer.truncate(0)
            self.needs_clear = False

        # when explicitly flushing, always flush the target as well
        self.flush_target()

    def flush_target(self) -> None:  # pragma: no cover
        '''Flush `target` if it is still open and has a `flush` method.'''
        # getattr needs a default: without one a target lacking `flush`
        # raises AttributeError instead of simply being skipped.
        if not getattr(self.target, 'closed', False) and getattr(
            self.target, 'flush', None
        ):
            self.target.flush()

    def __enter__(self) -> WrappingIO:
        return self

    def fileno(self) -> int:
        return self.target.fileno()

    def isatty(self) -> bool:
        return self.target.isatty()

    def read(self, n: int = -1) -> str:
        return self.target.read(n)

    def readable(self) -> bool:
        return self.target.readable()

    def readline(self, limit: int = -1) -> str:
        return self.target.readline(limit)

    def readlines(self, hint: int = -1) -> list[str]:
        return self.target.readlines(hint)

    def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
        return self.target.seek(offset, whence)

    def seekable(self) -> bool:
        return self.target.seekable()

    def tell(self) -> int:
        return self.target.tell()

    def truncate(self, size: types.Optional[int] = None) -> int:
        return self.target.truncate(size)

    def writable(self) -> bool:
        return self.target.writable()

    def writelines(self, lines: Iterable[str]) -> None:
        '''Write every item of `lines`, honouring capture like `write`.'''
        if self.capturing:
            # Route through write() so the text is buffered (and listeners
            # are notified) instead of bypassing the capture.
            for line in lines:
                self.write(line)
        else:
            self.target.writelines(lines)

    def close(self) -> None:
        '''Write out any captured text, then close `target`.'''
        # _flush() rather than flush(): the latter leaves captured text in
        # the buffer, where it would be discarded once the target closes.
        if not getattr(self.target, 'closed', False):
            self._flush()
        self.target.close()

    def __next__(self) -> str:
        return self.target.__next__()

    def __iter__(self) -> Iterator[str]:
        return self.target.__iter__()

    def __exit__(
        self,
        __t: Type[BaseException] | None,
        __value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> None:
        self.close()

303 

304 

class StreamWrapper:
    '''
    Wrap stdout and stderr globally.

    `wrap_stdout()` / `wrap_stderr()` replace `sys.stdout` / `sys.stderr`
    with `WrappingIO` proxies around the original streams. Wraps are
    reference counted (`wrapped_stdout` / `wrapped_stderr`), so the original
    stream is only restored by the matching last `unwrap_*()` call.

    `start_capturing()` / `stop_capturing()` maintain a capture count. While
    it is positive the wrappers buffer their output, and the buffers are
    written out once it drops back to zero. Wrapping a stream also installs
    (once) an excepthook that flushes captured output after the original
    hook has run.
    '''

    stdout: base.TextIO | WrappingIO
    stderr: base.TextIO | WrappingIO
    original_excepthook: types.Callable[
        [
            types.Type[BaseException],
            BaseException,
            TracebackType | None,
        ],
        None,
    ]
    wrapped_stdout: int = 0
    wrapped_stderr: int = 0
    wrapped_excepthook: int = 0
    capturing: int = 0
    listeners: set

    def __init__(self):
        self.stdout = self.original_stdout = sys.stdout
        self.stderr = self.original_stderr = sys.stderr
        self.original_excepthook = sys.excepthook
        self.wrapped_stdout = 0
        self.wrapped_stderr = 0
        self.wrapped_excepthook = 0
        self.capturing = 0
        self.listeners = set()

        if env_flag('WRAP_STDOUT', default=False):  # pragma: no cover
            self.wrap_stdout()

        if env_flag('WRAP_STDERR', default=False):  # pragma: no cover
            self.wrap_stderr()

    def start_capturing(self, bar: ProgressBarMixinBase | None = None) -> None:
        '''Register `bar` as a listener and increase the capture count.'''
        # `is not None` rather than truthiness: the bar's own __bool__ or
        # __len__ should not decide whether it gets registered.
        if bar is not None:  # pragma: no branch
            self.listeners.add(bar)

        self.capturing += 1
        self.update_capturing()

    def stop_capturing(self, bar: ProgressBarMixinBase | None = None) -> None:
        '''Unregister `bar` and decrease the capture count (never below 0).'''
        if bar is not None:  # pragma: no branch
            self.listeners.discard(bar)

        # Clamp, so an unbalanced extra stop cannot drive the count negative
        # and make the next start_capturing() ineffective.
        self.capturing = max(self.capturing - 1, 0)
        self.update_capturing()

    def update_capturing(self) -> None:  # pragma: no cover
        '''Push the capture state to the wrappers; flush when it ends.'''
        if isinstance(self.stdout, WrappingIO):
            self.stdout.capturing = self.capturing > 0

        if isinstance(self.stderr, WrappingIO):
            self.stderr.capturing = self.capturing > 0

        if self.capturing <= 0:
            self.flush()

    def wrap(self, stdout: bool = False, stderr: bool = False) -> None:
        if stdout:
            self.wrap_stdout()

        if stderr:
            self.wrap_stderr()

    def wrap_stdout(self) -> base.IO:
        '''Install (or re-reference) the stdout wrapper and return it.'''
        self.wrap_excepthook()

        if not self.wrapped_stdout:
            self.stdout = sys.stdout = WrappingIO(  # type: ignore
                self.original_stdout, listeners=self.listeners
            )
        self.wrapped_stdout += 1

        return sys.stdout

    def wrap_stderr(self) -> base.IO:
        '''Install (or re-reference) the stderr wrapper and return it.'''
        self.wrap_excepthook()

        if not self.wrapped_stderr:
            self.stderr = sys.stderr = WrappingIO(  # type: ignore
                self.original_stderr, listeners=self.listeners
            )
        self.wrapped_stderr += 1

        return sys.stderr

    def unwrap_excepthook(self) -> None:
        if self.wrapped_excepthook:
            self.wrapped_excepthook -= 1
            sys.excepthook = self.original_excepthook

    def wrap_excepthook(self) -> None:
        if not self.wrapped_excepthook:
            logger.debug('wrapping excepthook')
            self.wrapped_excepthook += 1
            sys.excepthook = self.excepthook

    def unwrap(self, stdout: bool = False, stderr: bool = False) -> None:
        if stdout:
            self.unwrap_stdout()

        if stderr:
            self.unwrap_stderr()

    def unwrap_stdout(self) -> None:
        '''Drop one stdout wrap; the last one restores the original stream.'''
        if self.wrapped_stdout > 1:
            self.wrapped_stdout -= 1
        else:
            # Emit anything still captured; once the wrapper is dropped its
            # buffer would otherwise be lost.
            self._flush_stdout()
            sys.stdout = self.original_stdout
            # Stop consulting the discarded wrapper in needs_clear() and
            # update_capturing().
            self.stdout = self.original_stdout
            self.wrapped_stdout = 0

    def unwrap_stderr(self) -> None:
        '''Drop one stderr wrap; the last one restores the original stream.'''
        if self.wrapped_stderr > 1:
            self.wrapped_stderr -= 1
        else:
            # Emit anything still captured; once the wrapper is dropped its
            # buffer would otherwise be lost.
            self._flush_stderr()
            sys.stderr = self.original_stderr
            # Stop consulting the discarded wrapper in needs_clear() and
            # update_capturing().
            self.stderr = self.original_stderr
            self.wrapped_stderr = 0

    def needs_clear(self) -> bool:  # pragma: no cover
        stdout_needs_clear = getattr(self.stdout, 'needs_clear', False)
        stderr_needs_clear = getattr(self.stderr, 'needs_clear', False)
        return stderr_needs_clear or stdout_needs_clear

    def flush(self) -> None:
        '''Write captured stdout and stderr text to the original streams.'''
        self._flush_stdout()
        self._flush_stderr()

    def _flush_stdout(self) -> None:
        '''Flush the stdout wrapper; disable redirection if that fails.'''
        if self.wrapped_stdout and isinstance(self.stdout, WrappingIO):
            try:
                self.stdout._flush()
            except io.UnsupportedOperation:  # pragma: no cover
                self.wrapped_stdout = 0
                logger.warning(
                    'Disabling stdout redirection, %r is not seekable',
                    sys.stdout,
                )
                # Actually disable it: otherwise writes keep landing in a
                # wrapper that is never flushed again.
                if sys.stdout is self.stdout:
                    sys.stdout = self.original_stdout
                self.stdout = self.original_stdout

    def _flush_stderr(self) -> None:
        '''Flush the stderr wrapper; disable redirection if that fails.'''
        if self.wrapped_stderr and isinstance(self.stderr, WrappingIO):
            try:
                self.stderr._flush()
            except io.UnsupportedOperation:  # pragma: no cover
                self.wrapped_stderr = 0
                logger.warning(
                    'Disabling stderr redirection, %r is not seekable',
                    sys.stderr,
                )
                # Actually disable it: otherwise writes keep landing in a
                # wrapper that is never flushed again.
                if sys.stderr is self.stderr:
                    sys.stderr = self.original_stderr
                self.stderr = self.original_stderr

    def excepthook(self, exc_type, exc_value, exc_traceback):
        '''Run the original excepthook, then flush captured output.'''
        self.original_excepthook(exc_type, exc_value, exc_traceback)
        self.flush()

464 

465 

class AttributeDict(dict):
    '''
    A dict that can be accessed with .attribute

    Values may be of any type; a missing key raises AttributeError on
    attribute access.

    >>> attrs = AttributeDict(spam=123)

    # Reading

    >>> attrs['spam']
    123
    >>> attrs.spam
    123

    # Read after update using attribute

    >>> attrs.spam = 456
    >>> attrs['spam']
    456
    >>> attrs.spam
    456

    # Read after update using dict access

    >>> attrs['spam'] = 123
    >>> attrs['spam']
    123
    >>> attrs.spam
    123

    # Delete using attribute access

    >>> del attrs.spam
    >>> attrs['spam']
    Traceback (most recent call last):
    ...
    KeyError: 'spam'
    >>> attrs.spam
    Traceback (most recent call last):
    ...
    AttributeError: No such attribute: spam
    >>> del attrs.spam
    Traceback (most recent call last):
    ...
    AttributeError: No such attribute: spam
    '''

    def __getattr__(self, name: str) -> types.Any:
        # Only called when normal attribute lookup fails
        try:
            return self[name]
        except KeyError:
            raise AttributeError("No such attribute: " + name) from None

    def __setattr__(self, name: str, value: types.Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        try:
            del self[name]
        except KeyError:
            raise AttributeError("No such attribute: " + name) from None

526 

527 

logger: logging.Logger = logging.getLogger(__name__)

# Shared, process-wide stream wrapper.
streams = StreamWrapper()
# Write out any output still captured when the interpreter exits.
atexit.register(streams.flush)