Skip to content

Commit c489de0

Browse files
Mypy check untyped functions (#5673)
Fixes #5657. ### Description Enable config to also type check untyped functions (not in tests). - Type annotations were added - Explicit casts were added - Variable redefinitions with other types were refactored - If no other option was available (or it would be overly verbose) `# type: ignore` was added ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Felix Schnabel <f.schnabel@tum.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4f584a9 commit c489de0

49 files changed

Lines changed: 248 additions & 209 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

monai/_extensions/loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def timeout(time, message):
3030
try:
3131
timer = Timer(time, interrupt_main)
3232
timer.daemon = True
33-
yield timer.start()
33+
timer.start()
34+
yield
3435
except KeyboardInterrupt as e:
3536
if timer is not None and timer.is_alive():
3637
raise e # interrupt from user?

monai/apps/auto3dseg/auto_runner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import subprocess
1515
from copy import deepcopy
1616
from time import sleep
17-
from typing import Any, Dict, List, Optional, Union
17+
from typing import Any, Dict, List, Optional, Union, cast
1818

1919
import numpy as np
2020
import torch
@@ -204,6 +204,8 @@ class AutoRunner:
204204
205205
"""
206206

207+
analyze_params: Optional[Dict]
208+
207209
def __init__(
208210
self,
209211
work_dir: str = "./work_dir",
@@ -561,7 +563,7 @@ def _train_algo_in_nni(self, history):
561563
nni_config_filename = os.path.abspath(os.path.join(self.work_dir, f"{name}_nni_config.yaml"))
562564
ConfigParser.export_config_file(nni_config, nni_config_filename, fmt="yaml", default_flow_style=None)
563565

564-
max_trial = min(self.hpo_tasks, default_nni_config["maxTrialNumber"])
566+
max_trial = min(self.hpo_tasks, cast(int, default_nni_config["maxTrialNumber"]))
565567
cmd = "nnictl create --config " + nni_config_filename + " --port 8088"
566568

567569
if mode_dry_run:
@@ -585,7 +587,7 @@ def run(self):
585587
Run the AutoRunner pipeline
586588
"""
587589
# step 1: data analysis
588-
if self.analyze:
590+
if self.analyze and self.analyze_params is not None:
589591
logger.info("Running data analysis...")
590592
da = DataAnalyzer(
591593
self.datalist_filename, self.dataroot, output_path=self.datastats_filename, **self.analyze_params

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,9 @@ def infer(self, image_file):
246246
configs_path = [os.path.join(config_dir, f) for f in os.listdir(config_dir)]
247247

248248
spec = importlib.util.spec_from_file_location("InferClass", infer_py)
249-
infer_class = importlib.util.module_from_spec(spec)
249+
infer_class = importlib.util.module_from_spec(spec) # type: ignore
250250
sys.modules["InferClass"] = infer_class
251-
spec.loader.exec_module(infer_class)
251+
spec.loader.exec_module(infer_class) # type: ignore
252252
return infer_class.InferClass(configs_path, *args, **kwargs)
253253

254254
def predict(self, predict_files: list, predict_params=None):

monai/apps/auto3dseg/data_analyzer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import warnings
1313
from os import path
14-
from typing import Dict, List, Optional, Union
14+
from typing import Any, Dict, List, Optional, Union, cast
1515

1616
import numpy as np
1717
import torch
@@ -227,7 +227,7 @@ def get_all_case_stats(self, key="training", transform_list=None):
227227
files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key)
228228
dataset = Dataset(data=files, transform=transform)
229229
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.worker, collate_fn=no_collation)
230-
result = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}
230+
result: Dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}
231231

232232
if not has_tqdm:
233233
warnings.warn("tqdm is not installed. not displaying the caching progress.")
@@ -261,7 +261,7 @@ def get_all_case_stats(self, key="training", transform_list=None):
261261
)
262262
result[DataStatsKeys.BY_CASE].append(stats_by_cases)
263263

264-
result[DataStatsKeys.SUMMARY] = summarizer.summarize(result[DataStatsKeys.BY_CASE])
264+
result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(List, result[DataStatsKeys.BY_CASE]))
265265

266266
if not self._check_data_uniformity([ImageStatsKeys.SPACING], result):
267267
print("Data spacing is not completely uniform. MONAI transforms may provide unexpected result")

monai/apps/auto3dseg/hpo_gen.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import os
1313
from abc import abstractmethod
1414
from copy import deepcopy
15-
from typing import Optional
15+
from typing import Optional, cast
1616
from warnings import warn
1717

1818
from monai.apps.auto3dseg.bundle_gen import BundleAlgo
@@ -147,7 +147,8 @@ def print_bundle_algo_instruction(self):
147147
logger.info("-" * 140)
148148
logger.info("If NNI will run in a remote env: ")
149149
logger.info(
150-
f"1. Copy the algorithm_templates folder {self.algo.template_path} to remote {{remote_algorithm_templates_dir}}"
150+
f"1. Copy the algorithm_templates folder {cast(BundleAlgo, self.algo).template_path} "
151+
f"to remote {{remote_algorithm_templates_dir}}"
151152
)
152153
logger.info(f"2. Copy the older {self.algo.get_output_path()} to the remote machine {{remote_algo_dir}}")
153154
logger.info("Then add the following line to the trialCommand in your NNI config: ")

monai/apps/deepedit/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,8 @@ def __init__(
503503
self.discrepancy = discrepancy
504504
self.probability = probability
505505
self._will_interact = None
506-
self.is_pos = None
507-
self.is_other = None
506+
self.is_pos: Optional[bool] = None
507+
self.is_other: Optional[bool] = None
508508
self.default_guidance = None
509509
self.guidance: Dict[str, List[List[int]]] = {}
510510

monai/apps/deepgrow/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import logging
1313
import os
14-
from typing import Dict, List
14+
from typing import Dict, List, Union
1515

1616
import numpy as np
1717

@@ -144,7 +144,7 @@ def _default_transforms(image_key, label_key, pixdim):
144144

145145

146146
def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
147-
data_list = []
147+
data_list: List[Dict[str, Union[str, int]]] = []
148148

149149
image_count = 0
150150
label_count = 0
@@ -211,7 +211,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
211211

212212

213213
def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
214-
data_list = []
214+
data_list: List[Dict[str, Union[str, int]]] = []
215215

216216
image_count = 0
217217
label_count = 0

monai/apps/deepgrow/transforms.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111
import json
12-
from typing import Callable, Dict, Hashable, Optional, Sequence, Union
12+
from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union
1313

1414
import numpy as np
1515
import torch
@@ -437,8 +437,8 @@ def __call__(self, data):
437437

438438
if np.all(np.less(current_size, self.spatial_size)):
439439
cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
440-
box_start = np.array([s.start for s in cropper.slices])
441-
box_end = np.array([s.stop for s in cropper.slices])
440+
box_start = np.array([s.start for s in cropper.slices]) # type: ignore
441+
box_end = np.array([s.stop for s in cropper.slices]) # type: ignore
442442
else:
443443
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
444444

@@ -523,11 +523,10 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num):
523523
pos = neg = []
524524

525525
if self.dimensions == 2:
526-
points = list(pos_clicks)
526+
points: List = list(pos_clicks)
527527
points.extend(neg_clicks)
528-
points = np.array(points)
529528

530-
slices = list(np.unique(points[:, self.axis]))
529+
slices = list(np.unique(np.array(points)[:, self.axis]))
531530
slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num)
532531

533532
if len(pos_clicks):
@@ -938,8 +937,7 @@ def _apply(self, image, guidance):
938937
for i, size_i in enumerate(image.shape):
939938
idx.append(slice_idx) if i == self.axis else idx.append(slice(0, size_i))
940939

941-
idx = tuple(idx)
942-
return image[idx], idx
940+
return image[tuple(idx)], tuple(idx)
943941

944942
def __call__(self, data):
945943
d = dict(data)

monai/apps/nuclick/transforms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import numpy as np
1616
import torch
1717

18-
from monai.config import KeysCollection
18+
from monai.config import KeysCollection, NdarrayOrTensor
1919
from monai.networks.layers import GaussianFilter
2020
from monai.transforms import MapTransform, Randomizable, SpatialPad
2121
from monai.utils import StrEnum, convert_to_numpy, optional_import
@@ -337,14 +337,15 @@ def _apply_gaussion(self, t):
337337

338338
def _seed_point(self, label):
339339
if distance_transform_cdt is None or not self.use_distance:
340+
indices: NdarrayOrTensor
340341
if hasattr(torch, "argwhere"):
341342
indices = torch.argwhere(label > 0)
342343
else:
343344
indices = np.argwhere(convert_to_numpy(label) > 0)
344345

345346
if len(indices) > 0:
346-
idx = self.R.randint(0, len(indices))
347-
return indices[idx, 0], indices[idx, 1]
347+
index = self.R.randint(0, len(indices))
348+
return indices[index, 0], indices[index, 1]
348349
return None
349350

350351
distance = distance_transform_cdt(label).flatten()

monai/apps/pathology/data/datasets.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import os
1313
import sys
14-
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
14+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
1515

1616
import numpy as np
1717

@@ -62,14 +62,14 @@ def __init__(
6262
):
6363
super().__init__(data, transform)
6464

65-
self.region_size = ensure_tuple_rep(region_size, 2)
66-
self.grid_shape = ensure_tuple_rep(grid_shape, 2)
67-
self.patch_size = ensure_tuple_rep(patch_size, 2)
65+
self.region_size = cast(Tuple[int, int], ensure_tuple_rep(region_size, 2))
66+
self.grid_shape = cast(Tuple[int, int], ensure_tuple_rep(grid_shape, 2))
67+
self.patch_size = cast(Tuple[int, int], ensure_tuple_rep(patch_size, 2))
6868

6969
self.image_path_list = list({x["image"] for x in self.data})
7070
self.image_reader_name = image_reader_name.lower()
7171
self.image_reader = WSIReader(backend=image_reader_name, **kwargs)
72-
self.wsi_object_dict = None
72+
self.wsi_object_dict: Optional[Dict] = None
7373
if self.image_reader_name != "openslide":
7474
# OpenSlide causes memory issue if we prefetch image objects
7575
self._fetch_wsi_objects()
@@ -85,11 +85,11 @@ def __getitem__(self, index):
8585
if self.image_reader_name == "openslide":
8686
img_obj = self.image_reader.read(sample["image"])
8787
else:
88-
img_obj = self.wsi_object_dict[sample["image"]]
88+
img_obj = cast(Dict, self.wsi_object_dict)[sample["image"]]
8989
location = [sample["location"][i] - self.region_size[i] // 2 for i in range(len(self.region_size))]
9090
images, _ = self.image_reader.get_data(
9191
img=img_obj,
92-
location=location,
92+
location=cast(Tuple[int, int], location),
9393
size=self.region_size,
9494
grid_shape=self.grid_shape,
9595
patch_size=self.patch_size,
@@ -209,7 +209,7 @@ def __init__(
209209
) -> None:
210210
super().__init__(data, transform)
211211

212-
self.patch_size = ensure_tuple_rep(patch_size, 2)
212+
self.patch_size = cast(Tuple[int, int], ensure_tuple_rep(patch_size, 2))
213213

214214
# set up whole slide image reader
215215
self.image_reader_name = image_reader_name.lower()
@@ -309,7 +309,7 @@ def _load_a_patch(self, index):
309309
this method, first, finds the whole slide image and the patch that should be extracted,
310310
then it loads the patch and provide it with its image name and the corresponding mask location.
311311
"""
312-
sample_num = np.argmax(self.cum_num_patches > index) - 1
312+
sample_num = cast(int, np.argmax(self.cum_num_patches > index)) - 1
313313
sample = self.data[sample_num]
314314
patch_num = index - self.cum_num_patches[sample_num]
315315
location_on_image = sample["image_locations"][patch_num]

0 commit comments

Comments
 (0)