Skip to content

Commit f92e628

Browse files
authored
proxy, bufio: fix client network error is summarized as backend network error (#1051)
1 parent 6894c0c commit f92e628

7 files changed

Lines changed: 123 additions & 15 deletions

File tree

pkg/proxy/backend/backend_conn_mgr.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) (
356356
mgr.processLock.Lock()
357357
defer func() {
358358
if err != nil && !pnet.IsMySQLError(err) {
359-
mgr.setQuitSourceByErr(err)
359+
mgr.SetQuitSourceByErr(err)
360360
}
361361
mgr.handshakeHandler.OnTraffic(mgr)
362362
now := time.Now()
@@ -864,7 +864,7 @@ func (mgr *BackendConnManager) setKeepAlive() {
864864
}
865865
}
866866

867-
func (mgr *BackendConnManager) setQuitSourceByErr(err error) {
867+
func (mgr *BackendConnManager) SetQuitSourceByErr(err error) {
868868
if err == nil {
869869
return
870870
}

pkg/proxy/client/client_conn.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ func (cc *ClientConnection) processMsg(ctx context.Context) error {
7272
cc.pkt.ResetSequence()
7373
clientPkt, err := cc.pkt.ReadPacket()
7474
if err != nil {
75+
cc.connMgr.SetQuitSourceByErr(err)
7576
return err
7677
}
7778
err = cc.connMgr.ExecuteCmd(ctx, clientPkt)

pkg/proxy/net/error.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
var (
1717
ErrReadConn = errors.New("failed to read the connection")
1818
ErrWriteConn = errors.New("failed to write the connection")
19-
ErrRelayConn = errors.New("failed to relay the connection")
2019
ErrFlushConn = errors.New("failed to flush the connection")
2120
ErrCloseConn = errors.New("failed to close the connection")
2221
ErrHandshakeTLS = errors.New("failed to complete tls handshake")

pkg/proxy/net/packetio.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,11 @@ func (p *packetIO) ForwardUntil(destIO PacketIO, isEnd func(firstByte byte, firs
396396
return p.wrapErr(errors.Wrap(err, ErrReadConn))
397397
}
398398
if err := destIO.WritePacket(data, false); err != nil {
399-
return p.wrapErr(errors.Wrap(err, ErrWriteConn))
399+
err = errors.Wrap(err, ErrWriteConn)
400+
if dest != nil {
401+
err = dest.wrapErr(err)
402+
}
403+
return err
400404
}
401405
} else {
402406
for {
@@ -409,7 +413,10 @@ func (p *packetIO) ForwardUntil(destIO PacketIO, isEnd func(firstByte byte, firs
409413
dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1)
410414
p.limitReader.N = int64(length + 4)
411415
if _, err := dest.readWriter.ReadFrom(&p.limitReader); err != nil {
412-
return p.wrapErr(errors.Wrap(err, ErrRelayConn))
416+
if errors.Is(err, bufio.ErrWriteFail) {
417+
return dest.wrapErr(errors.Wrap(err, ErrWriteConn))
418+
}
419+
return p.wrapErr(errors.Wrap(err, ErrReadConn))
413420
}
414421
p.inPackets++
415422
dest.outPackets++

pkg/proxy/net/packetio_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/go-mysql-org/go-mysql/mysql"
1313
"github.com/pingcap/tiproxy/lib/config"
14+
"github.com/pingcap/tiproxy/lib/util/errors"
1415
"github.com/pingcap/tiproxy/lib/util/logger"
1516
"github.com/pingcap/tiproxy/lib/util/security"
1617
"github.com/pingcap/tiproxy/lib/util/waitgroup"
@@ -558,6 +559,48 @@ func TestForwardUntilLongData(t *testing.T) {
558559
wg.Wait()
559560
}
560561

562+
func TestForwardUntilError(t *testing.T) {
563+
srvCh := make(chan *packetIO)
564+
var wg waitgroup.WaitGroup
565+
selfErr, peerErr := errors.New("self"), errors.New("peer")
566+
// client1 writes to server1
567+
// server1 forwards to server2
568+
// server2 writes to client2 while client2 closes
569+
wg.Run(func() {
570+
testTCPConn(t,
571+
func(t *testing.T, cli *packetIO) {
572+
data := make([]byte, DefaultConnBufferSize*2)
573+
data[0] = byte(0)
574+
require.NoError(t, cli.WritePacket(data, true))
575+
},
576+
func(t *testing.T, srv1 *packetIO) {
577+
srv1.ApplyOpts(WithWrapError(selfErr))
578+
srv2 := <-srvCh
579+
err := srv1.ForwardUntil(srv2, func(firstByte byte, firstPktLen int) (bool, bool) {
580+
return true, true
581+
}, func(response []byte) error {
582+
return srv2.Flush()
583+
})
584+
require.ErrorIs(t, err, peerErr)
585+
},
586+
1,
587+
)
588+
})
589+
wg.Run(func() {
590+
testTCPConn(t,
591+
func(t *testing.T, cli *packetIO) {
592+
require.NoError(t, cli.Close())
593+
},
594+
func(t *testing.T, srv2 *packetIO) {
595+
srv2.ApplyOpts(WithWrapError(peerErr))
596+
srvCh <- srv2
597+
},
598+
1,
599+
)
600+
})
601+
wg.Wait()
602+
}
603+
561604
func BenchmarkWritePacket(b *testing.B) {
562605
b.ReportAllocs()
563606
cli, srv := net.Pipe()

pkg/util/bufio/bufio.go

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
// Use of this source code is governed by a BSD-style
66
// license that can be found in the LICENSE file.
77

8-
// Package bufio is a simplified Golang bufio.
9-
// It mainly avoids calling TCPConn.ReadFrom in Writer.ReadFrom because
10-
// TCPConn.ReadFrom calls sys calls.
8+
// Package bufio is a simplified Golang bufio. It mainly changes:
9+
// - Avoid calling TCPConn.ReadFrom in Writer.ReadFrom because TCPConn.ReadFrom calls sys calls.
10+
// - Wrap the error with ErrWriteFail/ErrReadFail so that the caller knows which peer fails.
1111
package bufio
1212

1313
import (
14-
"errors"
1514
"io"
15+
16+
"github.com/pingcap/tiproxy/lib/util/errors"
1617
)
1718

1819
const (
@@ -22,6 +23,9 @@ const (
2223
var (
2324
ErrBufferFull = errors.New("bufio: buffer full")
2425
ErrNegativeCount = errors.New("bufio: negative count")
26+
27+
ErrReadFail = errors.New("read failed")
28+
ErrWriteFail = errors.New("write failed")
2529
)
2630

2731
// Buffered input.
@@ -250,9 +254,12 @@ func (b *Reader) Buffered() int { return b.w - b.r }
250254
// This may make multiple calls to the Read method of the underlying Reader.
251255
// If the underlying reader supports the WriteTo method,
252256
// this calls the underlying WriteTo without buffering.
257+
//
258+
// Wrap the error with ErrWriteFail/ErrReadFail so that the caller knows which peer fails.
253259
func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
254260
n, err = b.writeBuf(w)
255261
if err != nil {
262+
err = errors.Wrap(err, ErrWriteFail)
256263
return
257264
}
258265

@@ -280,7 +287,7 @@ func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
280287
b.err = nil
281288
}
282289

283-
return n, b.readErr()
290+
return n, errors.Wrap(b.readErr(), ErrReadFail)
284291
}
285292

286293
var errNegativeWrite = errors.New("bufio: writer returned negative count from Write")
@@ -430,15 +437,17 @@ func (b *Writer) Write(p []byte) (nn int, err error) {
430437
// supports the ReadFrom method, this calls the underlying ReadFrom.
431438
// If there is buffered data and an underlying ReadFrom, this fills
432439
// the buffer and writes it before calling ReadFrom.
440+
//
441+
// Wrap the error with ErrWriteFail/ErrReadFail so that the caller knows which peer fails.
433442
func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
434443
if b.err != nil {
435-
return 0, b.err
444+
return 0, errors.Wrap(b.err, ErrWriteFail)
436445
}
437446
var m int
438447
for {
439448
if b.Available() == 0 {
440449
if err1 := b.Flush(); err1 != nil {
441-
return n, err1
450+
return n, errors.Wrap(err1, ErrWriteFail)
442451
}
443452
}
444453
nr := 0
@@ -450,7 +459,7 @@ func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
450459
nr++
451460
}
452461
if nr == maxConsecutiveEmptyReads {
453-
return n, io.ErrNoProgress
462+
return n, errors.Wrap(io.ErrNoProgress, ErrReadFail)
454463
}
455464
b.n += m
456465
n += int64(m)
@@ -461,12 +470,14 @@ func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
461470
if err == io.EOF {
462471
// If we filled the buffer exactly, flush preemptively.
463472
if b.Available() == 0 {
464-
err = b.Flush()
473+
if err = b.Flush(); err != nil {
474+
err = errors.Wrap(err, ErrWriteFail)
475+
}
465476
} else {
466477
err = nil
467478
}
468479
}
469-
return n, err
480+
return n, errors.Wrap(err, ErrReadFail)
470481
}
471482

472483
// buffered input and output

pkg/util/bufio/bufio_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright 2025 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package bufio
5+
6+
import (
7+
"bytes"
8+
"io"
9+
"net"
10+
"testing"
11+
12+
"github.com/pingcap/tiproxy/pkg/util/waitgroup"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
func TestReadFrom(t *testing.T) {
17+
// The reader breaks and ReadFrom returns ErrReadFail.
18+
var wg waitgroup.WaitGroup
19+
p1, p2 := net.Pipe()
20+
writer := bytes.NewBuffer(make([]byte, 10))
21+
bufWriter := NewWriter(writer)
22+
wg.Run(func() {
23+
_, err := p1.Write([]byte("h"))
24+
require.NoError(t, err)
25+
require.NoError(t, p2.Close())
26+
require.NoError(t, p1.Close())
27+
}, nil)
28+
wg.Run(func() {
29+
n, err := bufWriter.ReadFrom(io.LimitReader(p2, 1))
30+
require.NoError(t, err)
31+
require.EqualValues(t, 1, n)
32+
_, err = bufWriter.ReadFrom(p2)
33+
require.ErrorIs(t, err, ErrReadFail)
34+
}, nil)
35+
wg.Wait()
36+
37+
// The writer breaks and ReadFrom returns ErrWriteFail.
38+
p1, p2 = net.Pipe()
39+
bufWriter = NewWriterSize(p1, 2)
40+
reader := bytes.NewBuffer([]byte("he"))
41+
n, err := bufWriter.ReadFrom(io.LimitReader(reader, 1))
42+
require.NoError(t, err)
43+
require.EqualValues(t, 1, n)
44+
require.NoError(t, p2.Close())
45+
_, err = bufWriter.ReadFrom(io.LimitReader(reader, 1))
46+
require.ErrorIs(t, err, ErrWriteFail)
47+
}

0 commit comments

Comments
 (0)