77import os
88import re
99import sys
10+ from types import TracebackType
11+ from typing import Iterable , Iterator , Type
1012
1113from python_utils import types
1214from python_utils .converters import scale_1024
1315from python_utils .terminal import get_terminal_size
14- from python_utils .time import epoch
15- from python_utils .time import format_time
16- from python_utils .time import timedelta_to_seconds
16+ from python_utils .time import epoch , format_time , timedelta_to_seconds
1717
1818if types .TYPE_CHECKING :
1919 from .bar import ProgressBar
3939
4040
4141def is_ansi_terminal (fd : types .IO , is_terminal : bool | None = None ) \
42- -> bool : # pragma: no cover
42+ -> bool : # pragma: no cover
4343 if is_terminal is None :
4444 # Jupyter Notebooks define this variable and support progress bars
4545 if 'JPY_PARENT_PID' in os .environ :
@@ -68,13 +68,13 @@ def is_ansi_terminal(fd: types.IO, is_terminal: bool | None = None) \
6868 except Exception :
6969 is_terminal = False
7070
71- return is_terminal
71+ return bool ( is_terminal )
7272
7373
7474def is_terminal (fd : types .IO , is_terminal : bool | None = None ) -> bool :
7575 if is_terminal is None :
7676 # Full ansi support encompasses what we expect from a terminal
77- is_terminal = is_ansi_terminal (True ) or None
77+ is_terminal = is_ansi_terminal (fd ) or None
7878
7979 if is_terminal is None :
8080 # Allow a environment variable override
@@ -88,11 +88,13 @@ def is_terminal(fd: types.IO, is_terminal: bool | None = None) -> bool:
8888 except Exception :
8989 is_terminal = False
9090
91- return is_terminal
91+ return bool ( is_terminal )
9292
9393
94- def deltas_to_seconds (* deltas ,
95- ** kwargs ) -> int | float | None : # default=ValueError):
94+ def deltas_to_seconds (
95+ * deltas ,
96+ ** kwargs
97+ ) -> int | float | None : # default=ValueError):
9698 '''
9799 Convert timedeltas and seconds as int to seconds as float while coalescing
98100
@@ -148,15 +150,9 @@ def no_color(value: types.StringTypes) -> types.StringTypes:
148150 'abc'
149151 '''
150152 if isinstance (value , bytes ):
151- pattern = '\\ \u001b \\ [.*?[@-~]'
152- pattern = pattern .encode ()
153- replace = b''
154- assert isinstance (pattern , bytes )
153+ return re .sub ('\\ \u001b \\ [.*?[@-~]' .encode (), b'' , value )
155154 else :
156- pattern = u'\x1b \\ [.*?[@-~]'
157- replace = ''
158-
159- return re .sub (pattern , replace , value )
155+ return re .sub (u'\x1b \\ [.*?[@-~]' , '' , value )
160156
161157
162158def len_color (value : types .StringTypes ) -> int :
@@ -190,30 +186,37 @@ def env_flag(name: str, default: bool | None = None) -> bool | None:
190186
191187
192188class WrappingIO :
193-
194- def __init__ (self , target : types .IO , capturing : bool = False ,
195- listeners : types .Set [ProgressBar ] = None ) -> None :
189+ buffer : io .StringIO
190+ target : types .IO
191+ capturing : bool
192+ listeners : set
193+ needs_clear : bool = False
194+
195+ def __init__ (
196+ self , target : types .IO , capturing : bool = False ,
197+ listeners : types .Set [ProgressBar ] = None
198+ ) -> None :
196199 self .buffer = io .StringIO ()
197200 self .target = target
198201 self .capturing = capturing
199202 self .listeners = listeners or set ()
200203 self .needs_clear = False
201204
202- def __getattr__ (self , name ): # pragma: no cover
203- return getattr (self .target , name )
204-
205- def write (self , value : str ) -> None :
205+ def write (self , value : str ) -> int :
206+ ret = 0
206207 if self .capturing :
207- self .buffer .write (value )
208+ ret += self .buffer .write (value )
208209 if '\n ' in value : # pragma: no branch
209210 self .needs_clear = True
210211 for listener in self .listeners : # pragma: no branch
211212 listener .update ()
212213 else :
213- self .target .write (value )
214+ ret += self .target .write (value )
214215 if '\n ' in value : # pragma: no branch
215216 self .flush_target ()
216217
218+ return ret
219+
217220 def flush (self ) -> None :
218221 self .buffer .flush ()
219222
@@ -233,9 +236,80 @@ def flush_target(self) -> None: # pragma: no cover
233236 if not self .target .closed and getattr (self .target , 'flush' ):
234237 self .target .flush ()
235238
239+ def __enter__ (self ) -> WrappingIO :
240+ return self
241+
242+ def fileno (self ) -> int :
243+ return self .target .fileno ()
244+
245+ def isatty (self ) -> bool :
246+ return self .target .isatty ()
247+
248+ def read (self , n : int = - 1 ) -> str :
249+ return self .target .read (n )
250+
251+ def readable (self ) -> bool :
252+ return self .target .readable ()
253+
254+ def readline (self , limit : int = - 1 ) -> str :
255+ return self .target .readline (limit )
256+
257+ def readlines (self , hint : int = - 1 ) -> list [str ]:
258+ return self .target .readlines (hint )
259+
260+ def seek (self , offset : int , whence : int = os .SEEK_SET ) -> int :
261+ return self .target .seek (offset , whence )
262+
263+ def seekable (self ) -> bool :
264+ return self .target .seekable ()
265+
266+ def tell (self ) -> int :
267+ return self .target .tell ()
268+
269+ def truncate (self , size : types .Optional [int ] = None ) -> int :
270+ return self .target .truncate (size )
271+
272+ def writable (self ) -> bool :
273+ return self .target .writable ()
274+
275+ def writelines (self , lines : Iterable [str ]) -> None :
276+ return self .target .writelines (lines )
277+
278+ def close (self ) -> None :
279+ self .flush ()
280+ self .target .close ()
281+
282+ def __next__ (self ) -> str :
283+ return self .target .__next__ ()
284+
285+ def __iter__ (self ) -> Iterator [str ]:
286+ return self .target .__iter__ ()
287+
288+ def __exit__ (
289+ self ,
290+ __t : Type [BaseException ] | None ,
291+ __value : BaseException | None ,
292+ __traceback : TracebackType | None
293+ ) -> None :
294+ self .close ()
295+
236296
237297class StreamWrapper :
238298 '''Wrap stdout and stderr globally'''
299+ stdout : types .Union [types .TextIO , WrappingIO ]
300+ stderr : types .Union [types .TextIO , WrappingIO ]
301+ original_excepthook : types .Callable [
302+ [
303+ types .Optional [
304+ types .Type [BaseException ]],
305+ types .Optional [BaseException ],
306+ types .Optional [TracebackType ],
307+ ], None ]
308+ wrapped_stdout : int = 0
309+ wrapped_stderr : int = 0
310+ wrapped_excepthook : int = 0
311+ capturing : int = 0
312+ listeners : set
239313
240314 def __init__ (self ):
241315 self .stdout = self .original_stdout = sys .stdout
@@ -291,8 +365,10 @@ def wrap_stdout(self) -> types.IO:
291365 self .wrap_excepthook ()
292366
293367 if not self .wrapped_stdout :
294- self .stdout = sys .stdout = WrappingIO (self .original_stdout ,
295- listeners = self .listeners )
368+ self .stdout = sys .stdout = WrappingIO ( # type: ignore
369+ self .original_stdout ,
370+ listeners = self .listeners
371+ )
296372 self .wrapped_stdout += 1
297373
298374 return sys .stdout
@@ -301,8 +377,10 @@ def wrap_stderr(self) -> types.IO:
301377 self .wrap_excepthook ()
302378
303379 if not self .wrapped_stderr :
304- self .stderr = sys .stderr = WrappingIO (self .original_stderr ,
305- listeners = self .listeners )
380+ self .stderr = sys .stderr = WrappingIO ( # type: ignore
381+ self .original_stderr ,
382+ listeners = self .listeners
383+ )
306384 self .wrapped_stderr += 1
307385
308386 return sys .stderr
@@ -346,22 +424,26 @@ def needs_clear(self) -> bool: # pragma: no cover
346424
347425 def flush (self ) -> None :
348426 if self .wrapped_stdout : # pragma: no branch
349- try :
350- self .stdout ._flush ()
351- except (io .UnsupportedOperation ,
352- AttributeError ): # pragma: no cover
353- self .wrapped_stdout = False
354- logger .warn ('Disabling stdout redirection, %r is not seekable' ,
355- sys .stdout )
427+ if isinstance (self .stdout , WrappingIO ): # pragma: no branch
428+ try :
429+ self .stdout ._flush ()
430+ except io .UnsupportedOperation : # pragma: no cover
431+ self .wrapped_stdout = False
432+ logger .warning (
433+ 'Disabling stdout redirection, %r is not seekable' ,
434+ sys .stdout
435+ )
356436
357437 if self .wrapped_stderr : # pragma: no branch
358- try :
359- self .stderr ._flush ()
360- except (io .UnsupportedOperation ,
361- AttributeError ): # pragma: no cover
362- self .wrapped_stderr = False
363- logger .warn ('Disabling stderr redirection, %r is not seekable' ,
364- sys .stderr )
438+ if isinstance (self .stderr , WrappingIO ): # pragma: no branch
439+ try :
440+ self .stderr ._flush ()
441+ except io .UnsupportedOperation : # pragma: no cover
442+ self .wrapped_stderr = False
443+ logger .warning (
444+ 'Disabling stderr redirection, %r is not seekable' ,
445+ sys .stderr
446+ )
365447
366448 def excepthook (self , exc_type , exc_value , exc_traceback ):
367449 self .original_excepthook (exc_type , exc_value , exc_traceback )
0 commit comments