Skip to content

Commit 5d6d21a

Browse files
authored
Make Converter (including pipe) correctly infer the return type (#1331)
When running in adapter mode.
1 parent 8765288 commit 5d6d21a

3 files changed

Lines changed: 70 additions & 21 deletions

File tree

src/attr/_compat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
PY_3_8_PLUS = sys.version_info[:2] >= (3, 8)
1414
PY_3_9_PLUS = sys.version_info[:2] >= (3, 9)
1515
PY_3_10_PLUS = sys.version_info[:2] >= (3, 10)
16+
PY_3_11_PLUS = sys.version_info[:2] >= (3, 11)
1617
PY_3_12_PLUS = sys.version_info[:2] >= (3, 12)
1718
PY_3_13_PLUS = sys.version_info[:2] >= (3, 13)
1819
PY_3_14_PLUS = sys.version_info[:2] >= (3, 14)

src/attr/_make.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ._compat import (
2323
PY_3_8_PLUS,
2424
PY_3_10_PLUS,
25+
PY_3_11_PLUS,
2526
_AnnotationExtractor,
2627
_get_annotations,
2728
get_generic_base,
@@ -2705,28 +2706,27 @@ def __init__(self, converter, *, takes_self=False, takes_field=False):
27052706
self.takes_self = takes_self
27062707
self.takes_field = takes_field
27072708

2708-
self._first_param_type = _AnnotationExtractor(
2709-
converter
2710-
).get_first_param_type()
2709+
ex = _AnnotationExtractor(converter)
2710+
self._first_param_type = ex.get_first_param_type()
27112711

2712-
self.__call__ = {
2713-
(False, False): self._takes_only_value,
2714-
(True, False): self._takes_instance,
2715-
(False, True): self._takes_field,
2716-
(True, True): self._takes_both,
2717-
}[self.takes_self, self.takes_field]
2718-
2719-
def _takes_only_value(self, value, instance, field):
2720-
return self.converter(value)
2721-
2722-
def _takes_instance(self, value, instance, field):
2723-
return self.converter(value, instance)
2724-
2725-
def _takes_field(self, value, instance, field):
2726-
return self.converter(value, field)
2712+
if not (self.takes_self or self.takes_field):
2713+
self.__call__ = lambda value, _, __: self.converter(value)
2714+
elif self.takes_self and not self.takes_field:
2715+
self.__call__ = lambda value, instance, __: self.converter(
2716+
value, instance
2717+
)
2718+
elif not self.takes_self and self.takes_field:
2719+
self.__call__ = lambda value, __, field: self.converter(
2720+
value, field
2721+
)
2722+
else:
2723+
self.__call__ = lambda value, instance, field: self.converter(
2724+
value, instance, field
2725+
)
27272726

2728-
def _takes_both(self, value, instance, field):
2729-
return self.converter(value, instance, field)
2727+
rt = ex.get_return_type()
2728+
if rt is not None:
2729+
self.__call__.__annotations__["return"] = rt
27302730

27312731
@staticmethod
27322732
def _get_global_name(attr_name: str) -> str:
@@ -2948,8 +2948,12 @@ def pipe_converter(val, inst, field):
29482948
if t:
29492949
pipe_converter.__annotations__["val"] = t
29502950

2951+
last = converters[-1]
2952+
if not PY_3_11_PLUS and isinstance(last, Converter):
2953+
last = last.__call__
2954+
29512955
# Get return type from last converter.
2952-
rt = _AnnotationExtractor(converters[-1]).get_return_type()
2956+
rt = _AnnotationExtractor(last).get_return_type()
29532957
if rt:
29542958
pipe_converter.__annotations__["return"] = rt
29552959

tests/test_converters.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import attr
1212

1313
from attr import Converter, Factory, attrib
14+
from attr._compat import _AnnotationExtractor
1415
from attr.converters import default_if_none, optional, pipe, to_bool
1516

1617

@@ -91,6 +92,26 @@ def save_args(*args):
9192

9293
assert (42, instance, field) == taken
9394

95+
def test_annotations_if_last_in_pipe(self):
96+
"""
97+
If the wrapped converter has annotations, they are copied to the
98+
Converter __call__.
99+
"""
100+
101+
def wrapped(_, __, ___) -> float:
102+
pass
103+
104+
c = Converter(wrapped)
105+
106+
assert float is c.__call__.__annotations__["return"]
107+
108+
# Doesn't overwrite globally.
109+
110+
c2 = Converter(int)
111+
112+
assert float is c.__call__.__annotations__["return"]
113+
assert None is c2.__call__.__annotations__.get("return")
114+
94115

95116
class TestOptional:
96117
"""
@@ -227,6 +248,29 @@ def test_empty(self):
227248

228249
assert o is pipe().converter(o, None, None)
229250

251+
def test_wrapped_annotation(self):
252+
"""
253+
The return type of the wrapped converter is copied into its __call__
254+
and ultimately into pipe's wrapped converter.
255+
"""
256+
257+
def last(value) -> bool:
258+
return bool(value)
259+
260+
@attr.s
261+
class C:
262+
x = attr.ib(converter=[Converter(int), Converter(last)])
263+
264+
i = C(5)
265+
266+
assert True is i.x
267+
assert (
268+
bool
269+
is _AnnotationExtractor(
270+
attr.fields(C).x.converter.__call__
271+
).get_return_type()
272+
)
273+
230274

231275
class TestToBool:
232276
def test_unhashable(self):

0 commit comments

Comments
 (0)