@@ -6,7 +6,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream, R
66use tokio:: net:: TcpStream ;
77use tracing:: { debug, enabled, trace, Level } ;
88
9- use std:: io:: Error ;
9+ use std:: io:: { Error , ErrorKind } ;
1010use std:: net:: SocketAddr ;
1111use std:: ops:: Deref ;
1212use std:: pin:: Pin ;
@@ -108,8 +108,8 @@ impl Stream {
108108 pub async fn check ( & mut self ) -> Result < ( ) , crate :: net:: Error > {
109109 let mut buf = [ 0u8 ; 1 ] ;
110110 match self {
111- Self :: Plain ( plain) => plain. get_mut ( ) . peek ( & mut buf) . await ?,
112- Self :: Tls ( tls) => tls. get_mut ( ) . get_mut ( ) . 0 . peek ( & mut buf) . await ?,
111+ Self :: Plain ( plain) => eof ( plain. get_mut ( ) . peek ( & mut buf) . await ) ?,
112+ Self :: Tls ( tls) => eof ( tls. get_mut ( ) . get_mut ( ) . 0 . peek ( & mut buf) . await ) ?,
113113 Self :: DevNull => 0 ,
114114 } ;
115115
@@ -126,8 +126,8 @@ impl Stream {
126126 let bytes = message. to_bytes ( ) ?;
127127
128128 match self {
129- Stream :: Plain ( ref mut stream) => stream. write_all ( & bytes) . await ?,
130- Stream :: Tls ( ref mut stream) => stream. write_all ( & bytes) . await ?,
129+ Stream :: Plain ( ref mut stream) => eof ( stream. write_all ( & bytes) . await ) ?,
130+ Stream :: Tls ( ref mut stream) => eof ( stream. write_all ( & bytes) . await ) ?,
131131 Self :: DevNull => ( ) ,
132132 }
133133
@@ -165,7 +165,7 @@ impl Stream {
165165 message : & impl Protocol ,
166166 ) -> Result < usize , crate :: net:: Error > {
167167 let sent = self . send ( message) . await ?;
168- self . flush ( ) . await ?;
168+ eof ( self . flush ( ) . await ) ?;
169169 trace ! ( "😳" ) ;
170170
171171 Ok ( sent)
@@ -180,7 +180,7 @@ impl Stream {
180180 for message in messages {
181181 sent += self . send ( message) . await ?;
182182 }
183- self . flush ( ) . await ?;
183+ eof ( self . flush ( ) . await ) ?;
184184 trace ! ( "😳" ) ;
185185 Ok ( sent)
186186 }
@@ -199,15 +199,15 @@ impl Stream {
199199
200200 /// Read data into a buffer, avoiding unnecessary allocations.
201201 pub async fn read_buf ( & mut self , bytes : & mut BytesMut ) -> Result < Message , crate :: net:: Error > {
202- let code = self . read_u8 ( ) . await ?;
203- let len = self . read_i32 ( ) . await ?;
202+ let code = eof ( self . read_u8 ( ) . await ) ?;
203+ let len = eof ( self . read_i32 ( ) . await ) ?;
204204
205205 bytes. put_u8 ( code) ;
206206 bytes. put_i32 ( len) ;
207207
208208 // Length must be at least 4 bytes.
209209 if len < 4 {
210- return Err ( crate :: net:: Error :: Eof ) ;
210+ return Err ( crate :: net:: Error :: UnexpectedEof ) ;
211211 }
212212
213213 let capacity = len as usize + 1 ;
@@ -218,7 +218,7 @@ impl Stream {
218218 bytes. set_len ( capacity) ;
219219 }
220220
221- self . read_exact ( & mut bytes[ 5 ..capacity] ) . await ?;
221+ eof ( self . read_exact ( & mut bytes[ 5 ..capacity] ) . await ) ?;
222222
223223 let message = Message :: new ( bytes. split ( ) . freeze ( ) ) ;
224224
@@ -261,6 +261,19 @@ impl Stream {
261261 }
262262}
263263
264+ fn eof < T > ( result : std:: io:: Result < T > ) -> Result < T , crate :: net:: Error > {
265+ match result {
266+ Ok ( val) => Ok ( val) ,
267+ Err ( err) => {
268+ if err. kind ( ) == ErrorKind :: UnexpectedEof {
269+ Err ( crate :: net:: Error :: UnexpectedEof )
270+ } else {
271+ Err ( crate :: net:: Error :: Io ( err) )
272+ }
273+ }
274+ }
275+ }
276+
264277/// Wrapper around SocketAddr
265278/// to make it easier to debug.
266279pub struct PeerAddr {
0 commit comments