Skip to content

Commit a51decc

Browse files
committed
Added type hints to widgets
1 parent 36c796f commit a51decc

1 file changed

Lines changed: 62 additions & 41 deletions

File tree

progressbar/widgets.py

Lines changed: 62 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
3+
34
import abc
45
import datetime
56
import functools
@@ -12,7 +13,7 @@
1213
from . import base, utils
1314

1415
if types.TYPE_CHECKING:
15-
from .bar import ProgressBar
16+
from .bar import ProgressBarMixinBase
1617

1718
MAX_DATE = datetime.date.max
1819
MAX_TIME = datetime.time.max
@@ -120,15 +121,15 @@ def __init__(self, format: str, new_style: bool = False, **kwargs):
120121

121122
def get_format(
122123
self,
123-
progress: ProgressBar,
124+
progress: ProgressBarMixinBase,
124125
data: Data,
125126
format: types.Optional[str] = None,
126127
) -> str:
127128
return format or self.format
128129

129130
def __call__(
130131
self,
131-
progress: ProgressBar,
132+
progress: ProgressBarMixinBase,
132133
data: Data,
133134
format: types.Optional[str] = None,
134135
):
@@ -148,7 +149,7 @@ def __call__(
148149
class WidthWidgetMixin(abc.ABC):
149150
'''Mixing to make sure widgets are only visible if the screen is within a
150151
specified size range so the progressbar fits on both large and small
151-
screens..
152+
screens.
152153
153154
Variables available:
154155
- min_width: Only display the widget if at least `min_width` is left
@@ -174,7 +175,7 @@ def __init__(self, min_width=None, max_width=None, **kwargs):
174175
self.min_width = min_width
175176
self.max_width = max_width
176177

177-
def check_size(self, progress: 'ProgressBar'):
178+
def check_size(self, progress: ProgressBarMixinBase):
178179
if self.min_width and self.min_width > progress.term_width:
179180
return False
180181
elif self.max_width and self.max_width < progress.term_width:
@@ -190,7 +191,7 @@ class WidgetBase(WidthWidgetMixin, metaclass=abc.ABCMeta):
190191
be updated. The widget's size may change between calls, but the widget may
191192
display incorrectly if the size changes drastically and repeatedly.
192193
193-
The boolean INTERVAL informs the ProgressBar that it should be
194+
The INTERVAL timedelta informs the ProgressBar that it should be
194195
updated more often because it is time sensitive.
195196
196197
The widgets are only visible if the screen is within a
@@ -214,7 +215,7 @@ class WidgetBase(WidthWidgetMixin, metaclass=abc.ABCMeta):
214215
copy = True
215216

216217
@abc.abstractmethod
217-
def __call__(self, progress: ProgressBar, data: Data):
218+
def __call__(self, progress: ProgressBarMixinBase, data: Data):
218219
'''Updates the widget.
219220
220221
progress - a reference to the calling ProgressBar
@@ -232,7 +233,7 @@ class AutoWidthWidgetBase(WidgetBase, metaclass=abc.ABCMeta):
232233
@abc.abstractmethod
233234
def __call__(
234235
self,
235-
progress: ProgressBar,
236+
progress: ProgressBarMixinBase,
236237
data: Data,
237238
width: int = 0,
238239
):
@@ -281,7 +282,7 @@ def __init__(self, format: str, **kwargs):
281282

282283
def __call__(
283284
self,
284-
progress: ProgressBar,
285+
progress: ProgressBarMixinBase,
285286
data: Data,
286287
format: types.Optional[str] = None,
287288
):
@@ -350,16 +351,20 @@ def __init__(
350351
**kwargs,
351352
):
352353
self.samples = samples
353-
self.key_prefix = (self.__class__.__name__ or key_prefix) + '_'
354+
self.key_prefix = (
355+
key_prefix if key_prefix else self.__class__.__name__
356+
) + '_'
354357
TimeSensitiveWidgetBase.__init__(self, **kwargs)
355358

356-
def get_sample_times(self, progress: ProgressBar, data: Data):
359+
def get_sample_times(self, progress: ProgressBarMixinBase, data: Data):
357360
return progress.extra.setdefault(self.key_prefix + 'sample_times', [])
358361

359-
def get_sample_values(self, progress: ProgressBar, data: Data):
362+
def get_sample_values(self, progress: ProgressBarMixinBase, data: Data):
360363
return progress.extra.setdefault(self.key_prefix + 'sample_values', [])
361364

362-
def __call__(self, progress: ProgressBar, data: Data, delta: bool = False):
365+
def __call__(
366+
self, progress: ProgressBarMixinBase, data: Data, delta: bool = False
367+
):
363368
sample_times = self.get_sample_times(progress, data)
364369
sample_values = self.get_sample_values(progress, data)
365370

@@ -423,7 +428,7 @@ def __init__(
423428
self.format_NA = format_NA
424429

425430
def _calculate_eta(
426-
self, progress: ProgressBar, data: Data, value, elapsed
431+
self, progress: ProgressBarMixinBase, data: Data, value, elapsed
427432
):
428433
'''Updates the widget to show the ETA or total time when finished.'''
429434
if elapsed:
@@ -438,7 +443,7 @@ def _calculate_eta(
438443

439444
def __call__(
440445
self,
441-
progress: ProgressBar,
446+
progress: ProgressBarMixinBase,
442447
data: Data,
443448
value=None,
444449
elapsed=None,
@@ -484,7 +489,7 @@ class AbsoluteETA(ETA):
484489
'''Widget which attempts to estimate the absolute time of arrival.'''
485490

486491
def _calculate_eta(
487-
self, progress: ProgressBar, data: Data, value, elapsed
492+
self, progress: ProgressBarMixinBase, data: Data, value, elapsed
488493
):
489494
eta_seconds = ETA._calculate_eta(self, progress, data, value, elapsed)
490495
now = datetime.datetime.now()
@@ -522,7 +527,7 @@ def __init__(self, **kwargs):
522527

523528
def __call__(
524529
self,
525-
progress: ProgressBar,
530+
progress: ProgressBarMixinBase,
526531
data: Data,
527532
value=None,
528533
elapsed=None,
@@ -561,7 +566,7 @@ def __init__(
561566

562567
def __call__(
563568
self,
564-
progress: ProgressBar,
569+
progress: ProgressBarMixinBase,
565570
data: Data,
566571
format: types.Optional[str] = None,
567572
):
@@ -603,7 +608,7 @@ def _speed(self, value, elapsed):
603608

604609
def __call__(
605610
self,
606-
progress: ProgressBar,
611+
progress: ProgressBarMixinBase,
607612
data,
608613
value=None,
609614
total_seconds_elapsed=None,
@@ -650,7 +655,7 @@ def __init__(self, **kwargs):
650655

651656
def __call__(
652657
self,
653-
progress: ProgressBar,
658+
progress: ProgressBarMixinBase,
654659
data,
655660
value=None,
656661
total_seconds_elapsed=None,
@@ -682,7 +687,7 @@ def __init__(
682687
self.fill = create_marker(fill, self.fill_wrap) if fill else None
683688
WidgetBase.__init__(self, **kwargs)
684689

685-
def __call__(self, progress: ProgressBar, data: Data, width=None):
690+
def __call__(self, progress: ProgressBarMixinBase, data: Data, width=None):
686691
'''Updates the widget to show the next marker or the first marker when
687692
finished'''
688693

@@ -709,7 +714,7 @@ def __call__(self, progress: ProgressBar, data: Data, width=None):
709714
# cast fill to the same type as marker
710715
fill = type(marker)(fill)
711716

712-
return fill + marker
717+
return fill + marker # type: ignore
713718

714719

715720
# Alias for backwards compatibility
@@ -723,7 +728,9 @@ def __init__(self, format='%(value)d', **kwargs):
723728
FormatWidgetMixin.__init__(self, format=format, **kwargs)
724729
WidgetBase.__init__(self, format=format, **kwargs)
725730

726-
def __call__(self, progress: ProgressBar, data: Data, format=None):
731+
def __call__(
732+
self, progress: ProgressBarMixinBase, data: Data, format=None
733+
):
727734
return FormatWidgetMixin.__call__(self, progress, data, format)
728735

729736

@@ -735,7 +742,9 @@ def __init__(self, format='%(percentage)3d%%', na='N/A%%', **kwargs):
735742
FormatWidgetMixin.__init__(self, format=format, **kwargs)
736743
WidgetBase.__init__(self, format=format, **kwargs)
737744

738-
def get_format(self, progress: ProgressBar, data: Data, format=None):
745+
def get_format(
746+
self, progress: ProgressBarMixinBase, data: Data, format=None
747+
):
739748
# If percentage is not available, display N/A%
740749
percentage = data.get('percentage', base.Undefined)
741750
if not percentage and percentage != 0:
@@ -747,14 +756,21 @@ def get_format(self, progress: ProgressBar, data: Data, format=None):
747756
class SimpleProgress(FormatWidgetMixin, WidgetBase):
748757
'''Returns progress as a count of the total (e.g.: "5 of 47")'''
749758

759+
max_width_cache: dict[
760+
types.Union[str, tuple[float, float | types.Type[base.UnknownLength]]],
761+
types.Optional[int],
762+
]
763+
750764
DEFAULT_FORMAT = '%(value_s)s of %(max_value_s)s'
751765

752766
def __init__(self, format=DEFAULT_FORMAT, **kwargs):
753767
FormatWidgetMixin.__init__(self, format=format, **kwargs)
754768
WidgetBase.__init__(self, format=format, **kwargs)
755769
self.max_width_cache = dict(default=self.max_width)
756770

757-
def __call__(self, progress: ProgressBar, data: Data, format=None):
771+
def __call__(
772+
self, progress: ProgressBarMixinBase, data: Data, format=None
773+
):
758774
# If max_value is not available, display N/A
759775
if data.get('max_value'):
760776
data['max_value_s'] = data.get('max_value')
@@ -773,7 +789,9 @@ def __call__(self, progress: ProgressBar, data: Data, format=None):
773789

774790
# Guess the maximum width from the min and max value
775791
key = progress.min_value, progress.max_value
776-
max_width = self.max_width_cache.get(key, self.max_width)
792+
max_width: types.Optional[int] = self.max_width_cache.get(
793+
key, self.max_width
794+
)
777795
if not max_width:
778796
temporary_data = data.copy()
779797
for value in key:
@@ -832,7 +850,7 @@ def __init__(
832850

833851
def __call__(
834852
self,
835-
progress: ProgressBar,
853+
progress: ProgressBarMixinBase,
836854
data: Data,
837855
width: int = 0,
838856
):
@@ -893,7 +911,7 @@ class BouncingBar(Bar, TimeSensitiveWidgetBase):
893911

894912
def __call__(
895913
self,
896-
progress: ProgressBar,
914+
progress: ProgressBarMixinBase,
897915
data: Data,
898916
width: int = 0,
899917
):
@@ -944,7 +962,7 @@ def update_mapping(self, **mapping: types.Dict[str, types.Any]):
944962

945963
def __call__(
946964
self,
947-
progress: ProgressBar,
965+
progress: ProgressBarMixinBase,
948966
data: Data,
949967
format: types.Optional[str] = None,
950968
):
@@ -984,12 +1002,12 @@ def __init__(self, name, markers, **kwargs):
9841002
Bar.__init__(self, **kwargs)
9851003
self.markers = [string_or_lambda(marker) for marker in markers]
9861004

987-
def get_values(self, progress: ProgressBar, data: Data):
1005+
def get_values(self, progress: ProgressBarMixinBase, data: Data):
9881006
return data['variables'][self.name] or []
9891007

9901008
def __call__(
9911009
self,
992-
progress: ProgressBar,
1010+
progress: ProgressBarMixinBase,
9931011
data: Data,
9941012
width: int = 0,
9951013
):
@@ -1038,8 +1056,8 @@ def __init__(
10381056
**kwargs,
10391057
)
10401058

1041-
def get_values(self, progress: ProgressBar, data: Data):
1042-
ranges = [0] * len(self.markers)
1059+
def get_values(self, progress: ProgressBarMixinBase, data: Data):
1060+
ranges = [0.0] * len(self.markers)
10431061
for value in data['variables'][self.name] or []:
10441062
if not isinstance(value, (int, float)):
10451063
# Progress is (value, max)
@@ -1113,19 +1131,22 @@ def __init__(
11131131

11141132
def __call__(
11151133
self,
1116-
progress: ProgressBar,
1134+
progress: ProgressBarMixinBase,
11171135
data: Data,
11181136
width: int = 0,
11191137
):
11201138
left = converters.to_unicode(self.left(progress, data, width))
11211139
right = converters.to_unicode(self.right(progress, data, width))
11221140
width -= progress.custom_len(left) + progress.custom_len(right)
11231141

1142+
max_value = progress.max_value
1143+
# mypy doesn't get that the first part of the if statement makes sure
1144+
# we get the correct type
11241145
if (
1125-
progress.max_value is not base.UnknownLength
1126-
and progress.max_value > 0
1146+
max_value is not base.UnknownLength
1147+
and max_value > 0 # type: ignore
11271148
):
1128-
percent = progress.value / progress.max_value
1149+
percent = progress.value / max_value # type: ignore
11291150
else:
11301151
percent = 0
11311152

@@ -1155,7 +1176,7 @@ def __init__(self, format, **kwargs):
11551176

11561177
def __call__( # type: ignore
11571178
self,
1158-
progress: ProgressBar,
1179+
progress: ProgressBarMixinBase,
11591180
data: Data,
11601181
width: int = 0,
11611182
format: FormatString = None,
@@ -1181,7 +1202,7 @@ def __init__(self, format='%(percentage)2d%%', na='N/A%%', **kwargs):
11811202

11821203
def __call__( # type: ignore
11831204
self,
1184-
progress: ProgressBar,
1205+
progress: ProgressBarMixinBase,
11851206
data: Data,
11861207
width: int = 0,
11871208
format: FormatString = None,
@@ -1209,7 +1230,7 @@ def __init__(
12091230

12101231
def __call__(
12111232
self,
1212-
progress: ProgressBar,
1233+
progress: ProgressBarMixinBase,
12131234
data: Data,
12141235
format: types.Optional[str] = None,
12151236
):
@@ -1260,7 +1281,7 @@ def __init__(
12601281

12611282
def __call__(
12621283
self,
1263-
progress: ProgressBar,
1284+
progress: ProgressBarMixinBase,
12641285
data: Data,
12651286
format: types.Optional[str] = None,
12661287
):

0 commit comments

Comments
 (0)