Skip to content

Commit b98e2a4

Browse files
feat(pam): support terminating active sessions (#167)
* feat: support killing active PAM sessions via gateway session registry Add a session registry to the gateway so that ALPN cancellation signals can find and close active proxy connections. When an admin terminates a session from the UI, the gateway now kills the proxy immediately instead of waiting for expiry. - Add pamSessions registry (map + mutex) to Gateway struct - Register/deregister sessions around HandlePAMProxy dispatch - HandlePAMCancellation now calls cancelSession to close the proxy conn - Per-session context so expiry timer exits cleanly on cancellation - HandleGatewayDisconnect on BaseProxyServer to exit the CLI proxy cleanly when the backend connection drops * fix: track multiple connections per session in gateway PAM registry The registry previously stored a single entry per session ID, so each new client connection (e.g., multiple psql windows) overwrote the previous one. On termination, only the last connection was killed. Change the registry to a slice per session ID so CancelPAMSession closes all active connections for the session. * fix: only trigger proxy shutdown on actual gateway errors HandleGatewayDisconnect was called unconditionally when errCh received a value, but io.Copy returns nil on clean EOF (normal client disconnect). This caused the entire proxy to shut down when a user simply exited psql. Only call HandleGatewayDisconnect when err != nil, so normal client disconnects don't kill the proxy. Also remove redundant shutdownCh select guard before shutdownOnce.Do since sync.Once already provides the once-only guarantee. * fix: distinguish gateway disconnect from normal client disconnect Split the shared errCh into separate gatewayErrCh and clientErrCh so we can detect which side closed the connection first. Only call HandleGatewayDisconnect when the gateway side drops — not when the client (e.g., psql) exits normally. 
Previously, HandleGatewayDisconnect was called unconditionally on any errCh value, which shut down the entire proxy when a user simply exited their client. With the err != nil guard, it was never called because both normal and admin-terminated closes produce nil from io.Copy. Also remove redundant shutdownCh select guard in HandleGatewayDisconnect since sync.Once already provides the once-only guarantee. * refactor: extract disconnect channel logic into base proxy helpers
1 parent 3933d3a commit b98e2a4

7 files changed

Lines changed: 123 additions & 41 deletions

File tree

packages/gateway-v2/gateway.go

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ type GatewayConfig struct {
8383
ReconnectDelay time.Duration
8484
}
8585

86+
// pamSessionEntry is one tracked PAM proxy connection: the TLS connection
// itself plus the cancel function of the context it was registered with.
type pamSessionEntry struct {
	// cancel is the context.CancelFunc supplied when the connection was registered.
	cancel context.CancelFunc
	// conn is the TLS connection that belongs to this entry.
	conn *tls.Conn
}
90+
8691
type Gateway struct {
8792
GatewayID string
8893

@@ -110,6 +115,10 @@ type Gateway struct {
110115
heartbeatStarted bool
111116
heartbeatMu sync.Mutex
112117
notifyOnce sync.Once
118+
119+
// PAM session registry for active proxy connections (multiple connections per session)
120+
pamSessions map[string][]*pamSessionEntry
121+
pamSessionsMu sync.Mutex
113122
}
114123

115124
// NewGateway creates a new gateway instance
@@ -137,9 +146,51 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) {
137146
cancel: cancel,
138147
pamCredentialsManager: pamCredentialsManager,
139148
pamSessionUploader: session.NewSessionUploader(httpClient, pamCredentialsManager),
149+
pamSessions: make(map[string][]*pamSessionEntry),
140150
}, nil
141151
}
142152

153+
// RegisterPAMSession registers an active PAM proxy connection for cancellation support
154+
func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) {
155+
g.pamSessionsMu.Lock()
156+
defer g.pamSessionsMu.Unlock()
157+
g.pamSessions[sessionID] = append(g.pamSessions[sessionID], &pamSessionEntry{cancel: cancel, conn: conn})
158+
}
159+
160+
// DeregisterPAMSession removes a specific connection from the session registry
161+
func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) {
162+
g.pamSessionsMu.Lock()
163+
defer g.pamSessionsMu.Unlock()
164+
entries := g.pamSessions[sessionID]
165+
for i, e := range entries {
166+
if e.conn == conn {
167+
g.pamSessions[sessionID] = append(entries[:i], entries[i+1:]...)
168+
break
169+
}
170+
}
171+
if len(g.pamSessions[sessionID]) == 0 {
172+
delete(g.pamSessions, sessionID)
173+
}
174+
}
175+
176+
// CancelPAMSession kills all active connections for a PAM session
177+
func (g *Gateway) CancelPAMSession(sessionID string) bool {
178+
g.pamSessionsMu.Lock()
179+
entries, ok := g.pamSessions[sessionID]
180+
if ok {
181+
delete(g.pamSessions, sessionID)
182+
}
183+
g.pamSessionsMu.Unlock()
184+
if !ok {
185+
return false
186+
}
187+
for _, e := range entries {
188+
e.conn.Close()
189+
e.cancel()
190+
}
191+
return true
192+
}
193+
143194
func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) {
144195
sendHeartbeat := func() error {
145196
if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil {
@@ -608,7 +659,10 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) {
608659
}
609660
return
610661
} else if forwardConfig.Mode == ForwardModePAM {
611-
if err := pam.HandlePAMProxy(g.ctx, tlsConn, &forwardConfig.PAMConfig, g.httpClient); err != nil {
662+
sessionCtx, sessionCancel := context.WithCancel(g.ctx)
663+
g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn)
664+
defer g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn)
665+
if err := pam.HandlePAMProxy(sessionCtx, tlsConn, &forwardConfig.PAMConfig, g.httpClient); err != nil {
612666
if err.Error() == "unexpected EOF" {
613667
log.Debug().Err(err).Msg("PAM proxy handler ended with unexpected connection termination")
614668
} else {
@@ -617,7 +671,7 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) {
617671
}
618672
return
619673
} else if forwardConfig.Mode == ForwardModePAMCancellation {
620-
if err := pam.HandlePAMCancellation(g.ctx, tlsConn, &forwardConfig.PAMConfig, g.httpClient); err != nil {
674+
if err := pam.HandlePAMCancellation(g.ctx, tlsConn, &forwardConfig.PAMConfig, g.httpClient, g.CancelPAMSession); err != nil {
621675
log.Error().Err(err).Msg("PAM cancellation proxy handler ended with error")
622676
}
623677
return

packages/pam/local/base-proxy.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,41 @@ func (b *BaseProxyServer) FallbackToAPITermination() {
262262
}
263263
}
264264

265+
// HandleGatewayDisconnect should be called when a gateway connection drops unexpectedly
266+
// (i.e., not initiated by the user via Ctrl+C). This happens when:
267+
// - An administrator terminates the session from the Infisical UI
268+
// - The session expires on the gateway side
269+
// - The gateway or relay goes down
270+
//
271+
// It prints a message and triggers proxy shutdown so the CLI process exits
272+
// cleanly instead of hanging with a dead backend connection.
273+
func (b *BaseProxyServer) HandleGatewayDisconnect() {
274+
b.shutdownOnce.Do(func() {
275+
fmt.Println("\nConnection to session lost. Shutting down proxy...")
276+
close(b.shutdownCh)
277+
b.cancel()
278+
})
279+
}
280+
281+
// NewDisconnectChannels creates the error channels used to distinguish gateway
282+
// disconnects from normal client disconnects.
283+
func (b *BaseProxyServer) NewDisconnectChannels() (gatewayErrCh, clientErrCh chan error) {
284+
return make(chan error, 1), make(chan error, 1)
285+
}
286+
287+
// WaitForDisconnect blocks until either the gateway or client side of a proxied
288+
// connection closes. If the gateway disconnects, the proxy shuts down.
289+
func (b *BaseProxyServer) WaitForDisconnect(gatewayErrCh, clientErrCh <-chan error, connCtx context.Context) {
290+
select {
291+
case <-gatewayErrCh:
292+
b.HandleGatewayDisconnect()
293+
case <-clientErrCh:
294+
// Normal client disconnect, proxy stays running
295+
case <-connCtx.Done():
296+
log.Info().Msg("Connection cancelled by context")
297+
}
298+
}
299+
265300
// WaitForConnectionsWithTimeout waits for active connections to close with a timeout
266301
func (b *BaseProxyServer) WaitForConnectionsWithTimeout(timeout time.Duration) {
267302
done := make(chan struct{})

packages/pam/local/database-proxy.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ func (p *DatabaseProxyServer) handleConnection(clientConn net.Conn) {
265265
connCtx, connCancel := context.WithCancel(p.ctx)
266266
defer connCancel()
267267

268-
errCh := make(chan error, 2)
268+
gatewayErrCh, clientErrCh := p.NewDisconnectChannels()
269269

270-
// Bidirectional data forwarding with context cancellation
270+
// Gateway → Client: if this side closes first, the gateway dropped the connection
271271
go func() {
272272
defer connCancel()
273273
_, err := io.Copy(clientConn, gatewayConn)
@@ -278,9 +278,10 @@ func (p *DatabaseProxyServer) handleConnection(clientConn net.Conn) {
278278
log.Debug().Err(err).Msg("Gateway to client copy ended")
279279
}
280280
}
281-
errCh <- err
281+
gatewayErrCh <- err
282282
}()
283283

284+
// Client → Gateway: if this side closes first, the client disconnected normally
284285
go func() {
285286
defer connCancel()
286287
_, err := io.Copy(gatewayConn, clientConn)
@@ -291,14 +292,10 @@ func (p *DatabaseProxyServer) handleConnection(clientConn net.Conn) {
291292
log.Debug().Err(err).Msg("Client to gateway copy ended")
292293
}
293294
}
294-
errCh <- err
295+
clientErrCh <- err
295296
}()
296297

297-
select {
298-
case <-errCh:
299-
case <-connCtx.Done():
300-
log.Info().Msg("Connection cancelled by context")
301-
}
298+
p.WaitForDisconnect(gatewayErrCh, clientErrCh, connCtx)
302299

303300
log.Info().Msgf("Connection closed for client: %s", clientConn.RemoteAddr().String())
304301
}

packages/pam/local/kubernetes-proxy.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,9 @@ func (p *KubernetesProxyServer) handleConnection(clientConn net.Conn) {
304304
connCtx, connCancel := context.WithCancel(p.ctx)
305305
defer connCancel()
306306

307-
errCh := make(chan error, 2)
307+
gatewayErrCh, clientErrCh := p.NewDisconnectChannels()
308308

309-
// Bidirectional data forwarding with context cancellation
309+
// Gateway → Client: if this side closes first, the gateway dropped the connection
310310
go func() {
311311
defer connCancel()
312312
_, err := io.Copy(clientConn, gatewayConn)
@@ -317,9 +317,10 @@ func (p *KubernetesProxyServer) handleConnection(clientConn net.Conn) {
317317
log.Debug().Err(err).Msg("Gateway to client copy ended")
318318
}
319319
}
320-
errCh <- err
320+
gatewayErrCh <- err
321321
}()
322322

323+
// Client → Gateway: if this side closes first, the client disconnected normally
323324
go func() {
324325
defer connCancel()
325326
_, err := io.Copy(gatewayConn, clientConn)
@@ -330,14 +331,10 @@ func (p *KubernetesProxyServer) handleConnection(clientConn net.Conn) {
330331
log.Debug().Err(err).Msg("Client to gateway copy ended")
331332
}
332333
}
333-
errCh <- err
334+
clientErrCh <- err
334335
}()
335336

336-
select {
337-
case <-errCh:
338-
case <-connCtx.Done():
339-
log.Info().Msg("Connection cancelled by context")
340-
}
337+
p.WaitForDisconnect(gatewayErrCh, clientErrCh, connCtx)
341338

342339
log.Info().Msgf("Connection closed for client: %s", clientConn.RemoteAddr().String())
343340
}

packages/pam/local/redis-proxy.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,9 @@ func (p *RedisProxyServer) handleConnection(clientConn net.Conn) {
250250
connCtx, connCancel := context.WithCancel(p.ctx)
251251
defer connCancel()
252252

253-
errCh := make(chan error, 2)
253+
gatewayErrCh, clientErrCh := p.NewDisconnectChannels()
254254

255-
// Bidirectional data forwarding with context cancellation
255+
// Gateway → Client: if this side closes first, the gateway dropped the connection
256256
go func() {
257257
defer connCancel()
258258
_, err := io.Copy(clientConn, gatewayConn)
@@ -263,9 +263,10 @@ func (p *RedisProxyServer) handleConnection(clientConn net.Conn) {
263263
log.Debug().Err(err).Msg("Gateway to client copy ended")
264264
}
265265
}
266-
errCh <- err
266+
gatewayErrCh <- err
267267
}()
268268

269+
// Client → Gateway: if this side closes first, the client disconnected normally
269270
go func() {
270271
defer connCancel()
271272
_, err := io.Copy(gatewayConn, clientConn)
@@ -276,14 +277,10 @@ func (p *RedisProxyServer) handleConnection(clientConn net.Conn) {
276277
log.Debug().Err(err).Msg("Client to gateway copy ended")
277278
}
278279
}
279-
errCh <- err
280+
clientErrCh <- err
280281
}()
281282

282-
select {
283-
case <-errCh:
284-
case <-connCtx.Done():
285-
log.Info().Msg("Connection cancelled by context")
286-
}
283+
p.WaitForDisconnect(gatewayErrCh, clientErrCh, connCtx)
287284

288285
log.Info().Msgf("Connection closed for client: %s", clientConn.RemoteAddr().String())
289286
}

packages/pam/local/ssh-proxy.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -361,10 +361,9 @@ func (p *SSHProxyServer) handleConnection(clientConn net.Conn) {
361361
connCtx, connCancel := context.WithCancel(p.ctx)
362362
defer connCancel()
363363

364-
errCh := make(chan error, 2)
364+
gatewayErrCh, clientErrCh := p.NewDisconnectChannels()
365365

366-
// Bidirectional data forwarding with context cancellation
367-
// Client (local SSH) → Gateway (SSH proxy)
366+
// Client (local SSH) → Gateway (SSH proxy): if this side closes first, the client disconnected normally
368367
go func() {
369368
defer connCancel()
370369
_, err := io.Copy(gatewayConn, clientConn)
@@ -375,10 +374,10 @@ func (p *SSHProxyServer) handleConnection(clientConn net.Conn) {
375374
log.Debug().Err(err).Msg("Client to gateway copy ended")
376375
}
377376
}
378-
errCh <- err
377+
clientErrCh <- err
379378
}()
380379

381-
// Gateway (SSH proxy) → Client (local SSH)
380+
// Gateway (SSH proxy) → Client (local SSH): if this side closes first, the gateway dropped the connection
382381
go func() {
383382
defer connCancel()
384383
_, err := io.Copy(clientConn, gatewayConn)
@@ -389,14 +388,10 @@ func (p *SSHProxyServer) handleConnection(clientConn net.Conn) {
389388
log.Debug().Err(err).Msg("Gateway to client copy ended")
390389
}
391390
}
392-
errCh <- err
391+
gatewayErrCh <- err
393392
}()
394393

395-
select {
396-
case <-errCh:
397-
case <-connCtx.Done():
398-
log.Debug().Msg("Connection cancelled by context")
399-
}
394+
p.WaitForDisconnect(gatewayErrCh, clientErrCh, connCtx)
400395

401396
log.Debug().Msgf("SSH connection closed for client: %s", clientConn.RemoteAddr().String())
402397
}

packages/pam/pam-proxy.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,16 @@ func HandlePAMCapabilities(ctx context.Context, conn *tls.Conn, gatewayName stri
7878
return nil
7979
}
8080

81-
func HandlePAMCancellation(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMConfig, httpClient *resty.Client) error {
81+
func HandlePAMCancellation(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMConfig, httpClient *resty.Client, cancelSession func(string) bool) error {
8282
log.Info().Str("sessionId", pamConfig.SessionId).Msg("Received session termination message")
8383

84+
// Kill the active proxy connection if it exists in the registry
85+
if cancelled := cancelSession(pamConfig.SessionId); cancelled {
86+
log.Info().Str("sessionId", pamConfig.SessionId).Msg("Active proxy session cancelled via registry")
87+
} else {
88+
log.Info().Str("sessionId", pamConfig.SessionId).Msg("No active proxy session found in registry (may have already ended)")
89+
}
90+
8491
if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil {
8592
log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session")
8693
}

0 commit comments

Comments
 (0)