Skip to content

Commit ebd78e0

Browse files
authored
[client] Update RaceDial to accept context for improved cancellation handling (#5849)
1 parent cf86b9a commit ebd78e0

3 files changed

Lines changed: 9 additions & 9 deletions

File tree

shared/relay/client/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
333333
dialers := c.getDialers()
334334

335335
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
336-
conn, err := rd.Dial()
336+
conn, err := rd.Dial(ctx)
337337
if err != nil {
338338
return nil, err
339339
}

shared/relay/client/dialer/race_dialer.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri
4040
}
4141
}
4242

43-
func (r *RaceDial) Dial() (net.Conn, error) {
43+
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
4444
connChan := make(chan dialResult, len(r.dialerFns))
4545
winnerConn := make(chan net.Conn, 1)
46-
abortCtx, abort := context.WithCancel(context.Background())
46+
abortCtx, abort := context.WithCancel(ctx)
4747
defer abort()
4848

4949
for _, dfn := range r.dialerFns {

shared/relay/client/dialer/race_dialer_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
7878
serverURL := "test.server.com"
7979

8080
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
81-
conn, err := rd.Dial()
81+
conn, err := rd.Dial(context.Background())
8282
if err == nil {
8383
t.Errorf("Expected an error with empty dialers, got nil")
8484
}
@@ -104,7 +104,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
104104
}
105105

106106
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
107-
conn, err := rd.Dial()
107+
conn, err := rd.Dial(context.Background())
108108
if err != nil {
109109
t.Errorf("Expected no error, got %v", err)
110110
}
@@ -137,7 +137,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
137137
}
138138

139139
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
140-
conn, err := rd.Dial()
140+
conn, err := rd.Dial(context.Background())
141141
if err != nil {
142142
t.Errorf("Expected no error, got %v", err)
143143
}
@@ -160,7 +160,7 @@ func TestRaceDialTimeout(t *testing.T) {
160160
}
161161

162162
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
163-
conn, err := rd.Dial()
163+
conn, err := rd.Dial(context.Background())
164164
if err == nil {
165165
t.Errorf("Expected an error, got nil")
166166
}
@@ -188,7 +188,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
188188
}
189189

190190
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
191-
conn, err := rd.Dial()
191+
conn, err := rd.Dial(context.Background())
192192
if err == nil {
193193
t.Errorf("Expected an error, got nil")
194194
}
@@ -230,7 +230,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
230230
}
231231

232232
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
233-
conn, err := rd.Dial()
233+
conn, err := rd.Dial(context.Background())
234234
if err != nil {
235235
t.Errorf("Expected no error, got %v", err)
236236
}

0 commit comments

Comments
 (0)