Skip to content

Commit 0dc3f74

Browse files
committed
Return self from context managers
1 parent cc8a411 commit 0dc3f74

4 files changed

Lines changed: 70 additions & 3 deletions

File tree

NEWS.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,19 @@ API-Incompatible Changes
66

77
* Drop support for Python 3.4.
88

9+
API Additions
10+
-------------
11+
12+
* `AssertRaisesContext` and `AssertWarnsContext` now return themselves
13+
when `__enter__()` is called. By extension it now easier to call
14+
`add_test()` with `assert_raises()` et al:
15+
16+
```python
17+
with assert_raises(KeyError) as context:
18+
context.add_test(...)
19+
...
20+
```
21+
922
News in asserts 0.9.1
1023
=====================
1124

asserts/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def __init__(self, exception, msg_fmt="{msg}"):
884884
self._tests = []
885885

886886
def __enter__(self):
887-
pass
887+
return self
888888

889889
def __exit__(self, exc_type, exc_val, exc_tb):
890890
if not exc_type:
@@ -1166,6 +1166,7 @@ def __init__(self, warning_class, msg_fmt="{msg}"):
11661166
def __enter__(self):
11671167
self._warning_context = catch_warnings(record=True)
11681168
self._warnings = self._warning_context.__enter__()
1169+
return self
11691170

11701171
def __exit__(self, exc_type, exc_val, exc_tb):
11711172
self._warning_context.__exit__(exc_type, exc_val, exc_tb)

asserts/__init__.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ from typing import (
1919
)
2020

2121
_E = TypeVar("_E", bound=BaseException)
22+
_S = TypeVar("_S")
2223

2324
class AssertRaisesContext(Generic[_E]):
2425
exception: Type[_E]
2526
msg_fmt: Text
2627
def __init__(self, exception: Type[_E], msg_fmt: Text = ...) -> None: ...
27-
def __enter__(self) -> AssertRaisesContext: ...
28+
def __enter__(self: _S) -> _S: ...
2829
def __exit__(
2930
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
3031
) -> Optional[bool]: ...
@@ -41,7 +42,7 @@ class AssertRaisesRegexContext(AssertRaisesContext[_E]):
4142

4243
class AssertWarnsContext:
4344
def __init__(self, warning_class: Type[Warning], msg_fmt: Text = ...) -> None: ...
44-
def __enter__(self) -> None: ...
45+
def __enter__(self: _S) -> _S: ...
4546
def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None: ...
4647
def format_message(self) -> Text: ...
4748
def add_test(self, cb: Callable[[Warning], None]) -> None: ...

test_asserts.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@
4747
)
4848

4949

50+
class Box:
51+
def __init__(self, initial_value):
52+
self.value = initial_value
53+
54+
5055
class _DummyObject(object):
5156
def __init__(self, value="x"):
5257
self.value = value
@@ -962,6 +967,29 @@ def test_assert_raises__wrong_exception_raised(self):
962967
else:
963968
fail("no exception raised")
964969

970+
def test_assert_raises__add_test_called(self):
971+
called = Box(False)
972+
973+
def extra_test(exc):
974+
assert_is_instance(exc, KeyError)
975+
called.value = True
976+
977+
with assert_raises(KeyError) as context:
978+
context.add_test(extra_test)
979+
raise KeyError()
980+
assert_true(called.value, "extra_test() was not called")
981+
982+
def test_assert_raises__add_test_not_called(self):
983+
called = Box(False)
984+
985+
def extra_test(_):
986+
called.value = True
987+
988+
with assert_raises(AssertionError):
989+
with assert_raises(KeyError) as context:
990+
context.add_test(extra_test)
991+
assert_false(called.value, "extra_test() was unexpectedly called")
992+
965993
# assert_raises_regex()
966994

967995
def test_assert_raises_regex__raises_right_exception(self):
@@ -1166,6 +1194,30 @@ def test_assert_warns__warning_handler_deinstalled_on_failure(self):
11661194
warn("bar", UserWarning)
11671195
assert_equal(1, len(warnings))
11681196

1197+
def test_assert_warns__add_test_called(self):
1198+
called = Box(False)
1199+
1200+
def extra_test(warning):
1201+
assert_is(warning.category, UserWarning)
1202+
called.value = True
1203+
return True
1204+
1205+
with assert_warns(UserWarning) as context:
1206+
context.add_test(extra_test)
1207+
warn("bar", UserWarning)
1208+
assert_true(called.value, "extra_test() was not called")
1209+
1210+
def test_assert_warns__add_test_not_called(self):
1211+
called = Box(False)
1212+
1213+
def extra_test(_):
1214+
called.value = True
1215+
1216+
with assert_raises(AssertionError):
1217+
with assert_warns(UserWarning) as context:
1218+
context.add_test(extra_test)
1219+
assert_false(called.value, "extra_test() was unexpectedly called")
1220+
11691221
# assert_warns_regex()
11701222

11711223
def test_assert_warns_regex__warned(self):

0 commit comments

Comments
 (0)