Skip to content

Commit 1c07a17

Browse files
Refactored test assertions that have suboptimal tests with numbers (#7671)
### Description As discussed in PR #7609, I tried to break the suboptimal test refactors to smaller pieces. Suboptimal Assert: Instead of using statements such as assertIsNone, assertIsInstance, always simply use assertTrue or assertFalse. This will decrease the code overall readability and increase the execution time as extra logic needed. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Han Wang <freddie.wanah@gmail.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent dc58e5c commit 1c07a17

14 files changed

Lines changed: 30 additions & 28 deletions

tests/test_bundle_get_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_get_all_bundles_list(self, params):
5353
output = get_all_bundles_list(**params)
5454
self.assertTrue(isinstance(output, list))
5555
self.assertTrue(isinstance(output[0], tuple))
56-
self.assertTrue(len(output[0]) == 2)
56+
self.assertEqual(len(output[0]), 2)
5757

5858
@parameterized.expand([TEST_CASE_1, TEST_CASE_5])
5959
@skip_if_quick

tests/test_compute_regression_metrics.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,22 +70,24 @@ def test_shape_reduction(self):
7070
mt = mt_fn(reduction="mean")
7171
mt(in_tensor, in_tensor)
7272
out_tensor = mt.aggregate()
73-
self.assertTrue(len(out_tensor.shape) == 1)
73+
self.assertEqual(len(out_tensor.shape), 1)
7474

7575
mt = mt_fn(reduction="sum")
7676
mt(in_tensor, in_tensor)
7777
out_tensor = mt.aggregate()
78-
self.assertTrue(len(out_tensor.shape) == 0)
78+
self.assertEqual(len(out_tensor.shape), 0)
7979

8080
mt = mt_fn(reduction="sum") # test reduction arg overriding
8181
mt(in_tensor, in_tensor)
8282
out_tensor = mt.aggregate(reduction="mean_channel")
83-
self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch)
83+
self.assertEqual(len(out_tensor.shape), 1)
84+
self.assertEqual(out_tensor.shape[0], batch)
8485

8586
mt = mt_fn(reduction="sum_channel")
8687
mt(in_tensor, in_tensor)
8788
out_tensor = mt.aggregate()
88-
self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch)
89+
self.assertEqual(len(out_tensor.shape), 1)
90+
self.assertEqual(out_tensor.shape[0], batch)
8991

9092
def test_compare_numpy(self):
9193
set_determinism(seed=123)

tests/test_handler_stats.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ def _update_metric(engine):
7676
if has_key_word.match(line):
7777
content_count += 1
7878
if epoch_log is True:
79-
self.assertTrue(content_count == max_epochs)
79+
self.assertEqual(content_count, max_epochs)
8080
else:
81-
self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter
81+
self.assertEqual(content_count, 2) # 2 = len([1, 2]) from event_filter
8282

8383
@parameterized.expand([[True], [get_event_filter([1, 3])]])
8484
def test_loss_print(self, iteration_log):
@@ -116,9 +116,9 @@ def _train_func(engine, batch):
116116
if has_key_word.match(line):
117117
content_count += 1
118118
if iteration_log is True:
119-
self.assertTrue(content_count == num_iters * max_epochs)
119+
self.assertEqual(content_count, num_iters * max_epochs)
120120
else:
121-
self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter
121+
self.assertEqual(content_count, 2) # 2 = len([1, 3]) from event_filter
122122

123123
def test_loss_dict(self):
124124
log_stream = StringIO()
@@ -150,7 +150,7 @@ def _train_func(engine, batch):
150150
for line in output_str.split("\n"):
151151
if has_key_word.match(line):
152152
content_count += 1
153-
self.assertTrue(content_count > 0)
153+
self.assertGreater(content_count, 0)
154154

155155
def test_loss_file(self):
156156
key_to_handler = "test_logging"
@@ -184,7 +184,7 @@ def _train_func(engine, batch):
184184
for line in output_str.split("\n"):
185185
if has_key_word.match(line):
186186
content_count += 1
187-
self.assertTrue(content_count > 0)
187+
self.assertGreater(content_count, 0)
188188

189189
def test_exception(self):
190190
# set up engine
@@ -239,7 +239,7 @@ def _update_metric(engine):
239239
for line in output_str.split("\n"):
240240
if has_key_word.match(line):
241241
content_count += 1
242-
self.assertTrue(content_count > 0)
242+
self.assertGreater(content_count, 0)
243243

244244
def test_default_logger(self):
245245
log_stream = StringIO()
@@ -274,7 +274,7 @@ def _train_func(engine, batch):
274274
for line in output_str.split("\n"):
275275
if has_key_word.match(line):
276276
content_count += 1
277-
self.assertTrue(content_count > 0)
277+
self.assertGreater(content_count, 0)
278278

279279

280280
if __name__ == "__main__":

tests/test_invertd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def test_invert(self):
134134
# 25300: 2 workers (cpu, non-macos)
135135
# 1812: 0 workers (gpu or macos)
136136
# 1821: windows torch 1.10.0
137-
self.assertTrue((reverted.size - n_good) < 40000, f"diff. {reverted.size - n_good}")
137+
self.assertLess((reverted.size - n_good), 40000, f"diff. {reverted.size - n_good}")
138138

139139
set_determinism(seed=None)
140140

tests/test_load_spacing_orientation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_load_spacingd(self, filename):
4848
ref = resample_to_output(anat, (1, 0.2, 1), order=1)
4949
t2 = time.time()
5050
print(f"time scipy: {t2 - t1}")
51-
self.assertTrue(t2 >= t1)
51+
self.assertGreaterEqual(t2, t1)
5252
np.testing.assert_allclose(res_dict["image"].affine, ref.affine)
5353
np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape)
5454
np.testing.assert_allclose(ref.get_fdata(), res_dict["image"][0], atol=0.05)
@@ -68,7 +68,7 @@ def test_load_spacingd_rotate(self, filename):
6868
ref = resample_to_output(anat, (1, 2, 3), order=1)
6969
t2 = time.time()
7070
print(f"time scipy: {t2 - t1}")
71-
self.assertTrue(t2 >= t1)
71+
self.assertGreaterEqual(t2, t1)
7272
np.testing.assert_allclose(res_dict["image"].affine, ref.affine)
7373
if "anatomical" not in filename:
7474
np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape)

tests/test_meta_affine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_linear_consistent(self, xform_cls, input_dict, atol):
160160
diff = np.abs(itk.GetArrayFromImage(ref_2) - itk.GetArrayFromImage(expected))
161161
avg_diff = np.mean(diff)
162162

163-
self.assertTrue(avg_diff < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
163+
self.assertLess(avg_diff, atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
164164

165165
@parameterized.expand(TEST_CASES_DICT)
166166
def test_linear_consistent_dict(self, xform_cls, input_dict, atol):
@@ -175,7 +175,7 @@ def test_linear_consistent_dict(self, xform_cls, input_dict, atol):
175175
diff = {k: np.abs(itk.GetArrayFromImage(ref_2[k]) - itk.GetArrayFromImage(expected[k])) for k in keys}
176176
avg_diff = {k: np.mean(diff[k]) for k in keys}
177177
for k in keys:
178-
self.assertTrue(avg_diff[k] < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
178+
self.assertLess(avg_diff[k], atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
179179

180180

181181
if __name__ == "__main__":

tests/test_persistentdataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_different_transforms(self):
165165
im1 = PersistentDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing)[0]
166166
im2 = PersistentDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing)[0]
167167
l2 = ((im1 - im2) ** 2).sum() ** 0.5
168-
self.assertTrue(l2 > 1)
168+
self.assertGreater(l2, 1)
169169

170170

171171
if __name__ == "__main__":

tests/test_rand_weighted_cropd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, e
154154
crop = RandWeightedCropd(**init_params)
155155
crop.set_random_state(10)
156156
result = crop(input_data)
157-
self.assertTrue(len(result) == init_params["num_samples"])
157+
self.assertEqual(len(result), init_params["num_samples"])
158158
_len = len(tuple(input_data.keys()))
159159
self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys()))
160160

tests/test_recon_net_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_reshape_channel_complex(self, test_data):
6464
def test_complex_normalize(self, test_data):
6565
result, mean, std = complex_normalize(test_data)
6666
result = result * std + mean
67-
self.assertTrue((((result - test_data) ** 2).mean() ** 0.5).item() < 1e-5)
67+
self.assertLess((((result - test_data) ** 2).mean() ** 0.5).item(), 1e-5)
6868

6969
@parameterized.expand(TEST_PAD)
7070
def test_pad(self, test_data):

tests/test_reg_loss_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def forward(self, x):
9999
# backward pass
100100
loss_val.backward()
101101
optimizer.step()
102-
self.assertTrue(init_loss > loss_val, "loss did not decrease")
102+
self.assertGreater(init_loss, loss_val, "loss did not decrease")
103103

104104

105105
if __name__ == "__main__":

0 commit comments

Comments
 (0)