Skip to content

Commit 8c84044

Browse files
authored
Add OnConnect and OnDisconnect callbacks to ManagedConnection (#3309)
* Add OnConnect and OnDisconnect callbacks to ManagedConnection This enables consumers to react to connection state changes without polling. Callbacks are invoked when: - OnConnect: connection is established (including reconnections) - OnDisconnect: connection is lost (with error reason) * Add OnConnect and OnDisconnect callbacks to ManagedConnection This enables consumers to react to connection state changes without polling. Callbacks are invoked when: - OnConnect: connection is established (including reconnections) - OnDisconnect: connection is lost (with error reason)
1 parent 116bf6c commit 8c84044

2 files changed

Lines changed: 140 additions & 3 deletions

File tree

websocket/connection.go

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,40 @@ type ManagedConnection struct {
8686

8787
// Used for the exponential backoff when connecting
8888
connectionBackoff wait.Backoff
89+
90+
// OnConnect is called when a connection is successfully established.
91+
// This callback is invoked each time the connection is established,
92+
// including reconnections.
93+
OnConnect func()
94+
95+
// OnDisconnect is called when a connection is lost.
96+
// The error parameter contains the reason for the disconnection.
97+
OnDisconnect func(error)
98+
}
99+
100+
// ConnectionOption is a functional option for configuring ManagedConnection.
101+
type ConnectionOption func(*ManagedConnection)
102+
103+
// WithOnConnect sets a callback that is invoked when a connection is established.
104+
func WithOnConnect(f func()) ConnectionOption {
105+
return func(c *ManagedConnection) {
106+
c.OnConnect = f
107+
}
108+
}
109+
110+
// WithOnDisconnect sets a callback that is invoked when a connection is lost.
111+
func WithOnDisconnect(f func(error)) ConnectionOption {
112+
return func(c *ManagedConnection) {
113+
c.OnDisconnect = f
114+
}
89115
}
90116

91117
// NewDurableSendingConnection creates a new websocket connection
92118
// that can only send messages to the endpoint it connects to.
93119
// The connection will continuously be kept alive and reconnected
94120
// in case of a loss of connectivity.
95-
func NewDurableSendingConnection(target string, logger *zap.SugaredLogger) *ManagedConnection {
96-
return NewDurableConnection(target, nil, logger)
121+
func NewDurableSendingConnection(target string, logger *zap.SugaredLogger, opts ...ConnectionOption) *ManagedConnection {
122+
return NewDurableConnection(target, nil, logger, opts...)
97123
}
98124

99125
// NewDurableSendingConnectionGuaranteed creates a new websocket connection
@@ -127,7 +153,7 @@ func NewDurableSendingConnectionGuaranteed(target string, duration time.Duration
127153
//
128154
// go func() {conn.Shutdown(); close(messageChan)}
129155
// go func() {for range messageChan {}}
130-
func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger) *ManagedConnection {
156+
func NewDurableConnection(target string, messageChan chan []byte, logger *zap.SugaredLogger, opts ...ConnectionOption) *ManagedConnection {
131157
websocketConnectionFactory := func() (rawConnection, error) {
132158
dialer := &websocket.Dialer{
133159
// This needs to be relatively short to avoid the connection getting blackholed for a long time
@@ -149,6 +175,11 @@ func NewDurableConnection(target string, messageChan chan []byte, logger *zap.Su
149175

150176
c := newConnection(websocketConnectionFactory, messageChan)
151177

178+
// Apply options before starting the goroutine
179+
for _, opt := range opts {
180+
opt(c)
181+
}
182+
152183
// Keep the connection alive asynchronously and reconnect on
153184
// connection failure.
154185
c.processingWg.Add(1)
@@ -164,8 +195,14 @@ func NewDurableConnection(target string, messageChan chan []byte, logger *zap.Su
164195
continue
165196
}
166197
logger.Debug("Connected to ", target)
198+
if c.OnConnect != nil {
199+
c.OnConnect()
200+
}
167201
if err := c.keepalive(); err != nil {
168202
logger.Errorw(fmt.Sprintf("Connection to %s broke down, reconnecting...", target), zap.Error(err))
203+
if c.OnDisconnect != nil {
204+
c.OnDisconnect(err)
205+
}
169206
}
170207
if err := c.closeConnection(); err != nil {
171208
logger.Errorw("Failed to close the connection after crashing", zap.Error(err))

websocket/connection_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,106 @@ func TestDurableConnectionSendsPingsRegularly(t *testing.T) {
419419
}
420420
}
421421

422+
func TestOnConnectAndOnDisconnectCallbacks(t *testing.T) {
423+
reconnectChan := make(chan struct{})
424+
425+
upgrader := websocket.Upgrader{}
426+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
427+
c, err := upgrader.Upgrade(w, r, nil)
428+
if err != nil {
429+
return
430+
}
431+
432+
// Wait for signal to drop the connection.
433+
<-reconnectChan
434+
c.Close()
435+
}))
436+
defer s.Close()
437+
438+
logger := ktesting.TestLogger(t)
439+
target := "ws" + strings.TrimPrefix(s.URL, "http")
440+
441+
onConnectCalled := make(chan struct{}, 10)
442+
onDisconnectCalled := make(chan error, 10)
443+
444+
conn := NewDurableSendingConnection(target, logger,
445+
WithOnConnect(func() {
446+
onConnectCalled <- struct{}{}
447+
}),
448+
WithOnDisconnect(func(err error) {
449+
onDisconnectCalled <- err
450+
}),
451+
)
452+
defer conn.Shutdown()
453+
454+
// Wait for the first OnConnect call
455+
select {
456+
case <-onConnectCalled:
457+
// Success - OnConnect was called
458+
case <-time.After(propagationTimeout):
459+
t.Fatal("Timed out waiting for OnConnect to be called")
460+
}
461+
462+
// Trigger a disconnect by closing the server-side connection
463+
reconnectChan <- struct{}{}
464+
465+
// Wait for OnDisconnect to be called
466+
select {
467+
case err := <-onDisconnectCalled:
468+
if err == nil {
469+
t.Error("Expected OnDisconnect to receive an error, got nil")
470+
}
471+
case <-time.After(propagationTimeout):
472+
t.Fatal("Timed out waiting for OnDisconnect to be called")
473+
}
474+
475+
// Wait for reconnection (OnConnect should be called again)
476+
select {
477+
case <-onConnectCalled:
478+
// Success - OnConnect was called again after reconnection
479+
case <-time.After(propagationTimeout):
480+
t.Fatal("Timed out waiting for OnConnect to be called after reconnection")
481+
}
482+
}
483+
484+
func TestOnConnectAndOnDisconnectCallbacksNotSet(t *testing.T) {
485+
reconnectChan := make(chan struct{})
486+
487+
upgrader := websocket.Upgrader{}
488+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
489+
c, err := upgrader.Upgrade(w, r, nil)
490+
if err != nil {
491+
return
492+
}
493+
494+
// Wait for signal to drop the connection.
495+
<-reconnectChan
496+
c.Close()
497+
}))
498+
defer s.Close()
499+
500+
logger := ktesting.TestLogger(t)
501+
target := "ws" + strings.TrimPrefix(s.URL, "http")
502+
503+
// Create connection without setting callbacks - should not panic
504+
conn := NewDurableSendingConnection(target, logger)
505+
defer conn.Shutdown()
506+
507+
// Wait for connection to be established
508+
err := wait.PollUntilContextTimeout(context.Background(), 50*time.Millisecond, propagationTimeout, true, func(ctx context.Context) (bool, error) {
509+
return conn.Status() == nil, nil
510+
})
511+
if err != nil {
512+
t.Fatal("Timed out waiting for connection to be established:", err)
513+
}
514+
515+
// Trigger disconnect - should not panic even without callbacks
516+
reconnectChan <- struct{}{}
517+
518+
// Wait a bit and verify no panic occurred
519+
time.Sleep(100 * time.Millisecond)
520+
}
521+
422522
func TestNewDurableSendingConnectionGuaranteed(t *testing.T) {
423523
// Unhappy case.
424524
logger := ktesting.TestLogger(t)

0 commit comments

Comments
 (0)