Skip to content

Commit f91c163

Browse files
authored
Make BufferedWriter generic over a protocol (#14015)
1 parent 44dac88 commit f91c163

2 files changed

Lines changed: 46 additions & 8 deletions

File tree

stdlib/@tests/test_cases/check_io.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
1-
from _io import BufferedReader
1+
from _io import BufferedReader, BufferedRWPair, BufferedWriter
22
from gzip import GzipFile
33
from io import FileIO, RawIOBase, TextIOWrapper
4+
from socket import SocketIO
5+
from typing import Any
46
from typing_extensions import assert_type
57

8+
socket: Any = None
9+
610
BufferedReader(RawIOBase())
11+
BufferedWriter(RawIOBase())
12+
BufferedWriter(SocketIO(socket, "r"))
13+
14+
BufferedRWPair(open("", "rb"), open("", "wb"))
715

816
assert_type(TextIOWrapper(FileIO("")).buffer, FileIO)
917
assert_type(TextIOWrapper(FileIO(13)).detach(), FileIO)

stdlib/_io.pyi

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,39 @@ class BufferedReader(BufferedIOBase, _BufferedIOBase, BinaryIO, Generic[_Buffere
163163
def seek(self, target: int, whence: int = 0, /) -> int: ...
164164
def truncate(self, pos: int | None = None, /) -> int: ...
165165

166+
@type_check_only
167+
class _BufferedWriterStream(Protocol):
168+
def write(self, b: WriteableBuffer, /) -> int | None: ...
169+
def seek(self, pos: int, whence: int, /) -> int: ...
170+
def tell(self) -> int: ...
171+
def truncate(self, size: int, /) -> int: ...
172+
def flush(self) -> object: ...
173+
def close(self) -> object: ...
174+
@property
175+
def closed(self) -> bool: ...
176+
def writable(self) -> bool: ...
177+
def seekable(self) -> bool: ...
178+
179+
# The following methods just pass through to the underlying stream. Since
180+
# not all streams support them, they are marked as optional here, and will
181+
# raise an AttributeError if called on a stream that does not support them.
182+
183+
# @property
184+
# def name(self) -> Any: ... # Type is inconsistent between the various I/O types.
185+
# @property
186+
# def mode(self) -> str: ...
187+
# def fileno(self) -> int: ...
188+
# def isatty(self) -> bool: ...
189+
190+
_BufferedWriterStreamT = TypeVar("_BufferedWriterStreamT", bound=_BufferedWriterStream, default=_BufferedWriterStream)
191+
166192
@disjoint_base
167-
class BufferedWriter(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of writelines in the base classes
168-
raw: RawIOBase
193+
class BufferedWriter(BufferedIOBase, _BufferedIOBase, BinaryIO, Generic[_BufferedWriterStreamT]): # type: ignore[misc] # incompatible definitions of writelines in the base classes
194+
raw: _BufferedWriterStreamT
169195
if sys.version_info >= (3, 14):
170-
def __init__(self, raw: RawIOBase, buffer_size: int = 131072) -> None: ...
196+
def __init__(self, raw: _BufferedWriterStreamT, buffer_size: int = 131072) -> None: ...
171197
else:
172-
def __init__(self, raw: RawIOBase, buffer_size: int = 8192) -> None: ...
198+
def __init__(self, raw: _BufferedWriterStreamT, buffer_size: int = 8192) -> None: ...
173199

174200
def write(self, buffer: ReadableBuffer, /) -> int: ...
175201
def seek(self, target: int, whence: int = 0, /) -> int: ...
@@ -190,11 +216,15 @@ class BufferedRandom(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore
190216
def truncate(self, pos: int | None = None, /) -> int: ...
191217

192218
@disjoint_base
193-
class BufferedRWPair(BufferedIOBase, _BufferedIOBase, Generic[_BufferedReaderStreamT]):
219+
class BufferedRWPair(BufferedIOBase, _BufferedIOBase, Generic[_BufferedReaderStreamT, _BufferedWriterStreamT]):
194220
if sys.version_info >= (3, 14):
195-
def __init__(self, reader: _BufferedReaderStreamT, writer: RawIOBase, buffer_size: int = 131072, /) -> None: ...
221+
def __init__(
222+
self, reader: _BufferedReaderStreamT, writer: _BufferedWriterStreamT, buffer_size: int = 131072, /
223+
) -> None: ...
196224
else:
197-
def __init__(self, reader: _BufferedReaderStreamT, writer: RawIOBase, buffer_size: int = 8192, /) -> None: ...
225+
def __init__(
226+
self, reader: _BufferedReaderStreamT, writer: _BufferedWriterStreamT, buffer_size: int = 8192, /
227+
) -> None: ...
198228

199229
def peek(self, size: int = 0, /) -> bytes: ...
200230

0 commit comments

Comments
 (0)