Skip to content

Commit 7dfd7cc

Browse files
committed
fixed type issues
1 parent b2e2777 commit 7dfd7cc

1 file changed

Lines changed: 126 additions & 44 deletions

File tree

progressbar/utils.py

Lines changed: 126 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import os
88
import re
99
import sys
10+
from types import TracebackType
11+
from typing import Iterable, Iterator, Type
1012

1113
from python_utils import types
1214
from python_utils.converters import scale_1024
1315
from python_utils.terminal import get_terminal_size
14-
from python_utils.time import epoch
15-
from python_utils.time import format_time
16-
from python_utils.time import timedelta_to_seconds
16+
from python_utils.time import epoch, format_time, timedelta_to_seconds
1717

1818
if types.TYPE_CHECKING:
1919
from .bar import ProgressBar
@@ -39,7 +39,7 @@
3939

4040

4141
def is_ansi_terminal(fd: types.IO, is_terminal: bool | None = None) \
42-
-> bool: # pragma: no cover
42+
-> bool: # pragma: no cover
4343
if is_terminal is None:
4444
# Jupyter Notebooks define this variable and support progress bars
4545
if 'JPY_PARENT_PID' in os.environ:
@@ -68,13 +68,13 @@ def is_ansi_terminal(fd: types.IO, is_terminal: bool | None = None) \
6868
except Exception:
6969
is_terminal = False
7070

71-
return is_terminal
71+
return bool(is_terminal)
7272

7373

7474
def is_terminal(fd: types.IO, is_terminal: bool | None = None) -> bool:
7575
if is_terminal is None:
7676
# Full ansi support encompasses what we expect from a terminal
77-
is_terminal = is_ansi_terminal(True) or None
77+
is_terminal = is_ansi_terminal(fd) or None
7878

7979
if is_terminal is None:
8080
# Allow a environment variable override
@@ -88,11 +88,13 @@ def is_terminal(fd: types.IO, is_terminal: bool | None = None) -> bool:
8888
except Exception:
8989
is_terminal = False
9090

91-
return is_terminal
91+
return bool(is_terminal)
9292

9393

94-
def deltas_to_seconds(*deltas,
95-
**kwargs) -> int | float | None: # default=ValueError):
94+
def deltas_to_seconds(
95+
*deltas,
96+
**kwargs
97+
) -> int | float | None: # default=ValueError):
9698
'''
9799
Convert timedeltas and seconds as int to seconds as float while coalescing
98100
@@ -148,15 +150,9 @@ def no_color(value: types.StringTypes) -> types.StringTypes:
148150
'abc'
149151
'''
150152
if isinstance(value, bytes):
151-
pattern = '\\\u001b\\[.*?[@-~]'
152-
pattern = pattern.encode()
153-
replace = b''
154-
assert isinstance(pattern, bytes)
153+
return re.sub('\\\u001b\\[.*?[@-~]'.encode(), b'', value)
155154
else:
156-
pattern = u'\x1b\\[.*?[@-~]'
157-
replace = ''
158-
159-
return re.sub(pattern, replace, value)
155+
return re.sub(u'\x1b\\[.*?[@-~]', '', value)
160156

161157

162158
def len_color(value: types.StringTypes) -> int:
@@ -190,30 +186,37 @@ def env_flag(name: str, default: bool | None = None) -> bool | None:
190186

191187

192188
class WrappingIO:
193-
194-
def __init__(self, target: types.IO, capturing: bool = False,
195-
listeners: types.Set[ProgressBar] = None) -> None:
189+
buffer: io.StringIO
190+
target: types.IO
191+
capturing: bool
192+
listeners: set
193+
needs_clear: bool = False
194+
195+
def __init__(
196+
self, target: types.IO, capturing: bool = False,
197+
listeners: types.Set[ProgressBar] = None
198+
) -> None:
196199
self.buffer = io.StringIO()
197200
self.target = target
198201
self.capturing = capturing
199202
self.listeners = listeners or set()
200203
self.needs_clear = False
201204

202-
def __getattr__(self, name): # pragma: no cover
203-
return getattr(self.target, name)
204-
205-
def write(self, value: str) -> None:
205+
def write(self, value: str) -> int:
206+
ret = 0
206207
if self.capturing:
207-
self.buffer.write(value)
208+
ret += self.buffer.write(value)
208209
if '\n' in value: # pragma: no branch
209210
self.needs_clear = True
210211
for listener in self.listeners: # pragma: no branch
211212
listener.update()
212213
else:
213-
self.target.write(value)
214+
ret += self.target.write(value)
214215
if '\n' in value: # pragma: no branch
215216
self.flush_target()
216217

218+
return ret
219+
217220
def flush(self) -> None:
218221
self.buffer.flush()
219222

@@ -233,9 +236,80 @@ def flush_target(self) -> None: # pragma: no cover
233236
if not self.target.closed and getattr(self.target, 'flush'):
234237
self.target.flush()
235238

239+
def __enter__(self) -> WrappingIO:
240+
return self
241+
242+
def fileno(self) -> int:
243+
return self.target.fileno()
244+
245+
def isatty(self) -> bool:
246+
return self.target.isatty()
247+
248+
def read(self, n: int = -1) -> str:
249+
return self.target.read(n)
250+
251+
def readable(self) -> bool:
252+
return self.target.readable()
253+
254+
def readline(self, limit: int = -1) -> str:
255+
return self.target.readline(limit)
256+
257+
def readlines(self, hint: int = -1) -> list[str]:
258+
return self.target.readlines(hint)
259+
260+
def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
261+
return self.target.seek(offset, whence)
262+
263+
def seekable(self) -> bool:
264+
return self.target.seekable()
265+
266+
def tell(self) -> int:
267+
return self.target.tell()
268+
269+
def truncate(self, size: types.Optional[int] = None) -> int:
270+
return self.target.truncate(size)
271+
272+
def writable(self) -> bool:
273+
return self.target.writable()
274+
275+
def writelines(self, lines: Iterable[str]) -> None:
276+
return self.target.writelines(lines)
277+
278+
def close(self) -> None:
279+
self.flush()
280+
self.target.close()
281+
282+
def __next__(self) -> str:
283+
return self.target.__next__()
284+
285+
def __iter__(self) -> Iterator[str]:
286+
return self.target.__iter__()
287+
288+
def __exit__(
289+
self,
290+
__t: Type[BaseException] | None,
291+
__value: BaseException | None,
292+
__traceback: TracebackType | None
293+
) -> None:
294+
self.close()
295+
236296

237297
class StreamWrapper:
238298
'''Wrap stdout and stderr globally'''
299+
stdout: types.Union[types.TextIO, WrappingIO]
300+
stderr: types.Union[types.TextIO, WrappingIO]
301+
original_excepthook: types.Callable[
302+
[
303+
types.Optional[
304+
types.Type[BaseException]],
305+
types.Optional[BaseException],
306+
types.Optional[TracebackType],
307+
], None]
308+
wrapped_stdout: int = 0
309+
wrapped_stderr: int = 0
310+
wrapped_excepthook: int = 0
311+
capturing: int = 0
312+
listeners: set
239313

240314
def __init__(self):
241315
self.stdout = self.original_stdout = sys.stdout
@@ -291,8 +365,10 @@ def wrap_stdout(self) -> types.IO:
291365
self.wrap_excepthook()
292366

293367
if not self.wrapped_stdout:
294-
self.stdout = sys.stdout = WrappingIO(self.original_stdout,
295-
listeners=self.listeners)
368+
self.stdout = sys.stdout = WrappingIO( # type: ignore
369+
self.original_stdout,
370+
listeners=self.listeners
371+
)
296372
self.wrapped_stdout += 1
297373

298374
return sys.stdout
@@ -301,8 +377,10 @@ def wrap_stderr(self) -> types.IO:
301377
self.wrap_excepthook()
302378

303379
if not self.wrapped_stderr:
304-
self.stderr = sys.stderr = WrappingIO(self.original_stderr,
305-
listeners=self.listeners)
380+
self.stderr = sys.stderr = WrappingIO( # type: ignore
381+
self.original_stderr,
382+
listeners=self.listeners
383+
)
306384
self.wrapped_stderr += 1
307385

308386
return sys.stderr
@@ -346,22 +424,26 @@ def needs_clear(self) -> bool: # pragma: no cover
346424

347425
def flush(self) -> None:
348426
if self.wrapped_stdout: # pragma: no branch
349-
try:
350-
self.stdout._flush()
351-
except (io.UnsupportedOperation,
352-
AttributeError): # pragma: no cover
353-
self.wrapped_stdout = False
354-
logger.warn('Disabling stdout redirection, %r is not seekable',
355-
sys.stdout)
427+
if isinstance(self.stdout, WrappingIO): # pragma: no branch
428+
try:
429+
self.stdout._flush()
430+
except io.UnsupportedOperation: # pragma: no cover
431+
self.wrapped_stdout = False
432+
logger.warning(
433+
'Disabling stdout redirection, %r is not seekable',
434+
sys.stdout
435+
)
356436

357437
if self.wrapped_stderr: # pragma: no branch
358-
try:
359-
self.stderr._flush()
360-
except (io.UnsupportedOperation,
361-
AttributeError): # pragma: no cover
362-
self.wrapped_stderr = False
363-
logger.warn('Disabling stderr redirection, %r is not seekable',
364-
sys.stderr)
438+
if isinstance(self.stderr, WrappingIO): # pragma: no branch
439+
try:
440+
self.stderr._flush()
441+
except io.UnsupportedOperation: # pragma: no cover
442+
self.wrapped_stderr = False
443+
logger.warning(
444+
'Disabling stderr redirection, %r is not seekable',
445+
sys.stderr
446+
)
365447

366448
def excepthook(self, exc_type, exc_value, exc_traceback):
367449
self.original_excepthook(exc_type, exc_value, exc_traceback)

0 commit comments

Comments
 (0)