Skip to content

Commit e67cf2a

Browse files
znullCopilot
andcommitted
Restore panic handler for Function stages
The version-2 Stage interface redesign dropped the panic recovery from goStage. Restore it: add SetPanicHandler(), recoverPanic(), and WithStagePanicHandler pipeline option. The goroutine uses stacked defers (close(done) → close stdin → close stdout → recoverPanic) so that when a Function panics, recoverPanic fires first (sets s.err), then cleanup runs, then done closes — allowing Wait() to return the caught panic error. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 23e9018 commit e67cf2a

3 files changed

Lines changed: 75 additions & 17 deletions

File tree

pipe/function.go

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,26 @@ func Function(name string, f StageFunc) Stage {
3232
// goStage is a `Stage` that does its work by running an arbitrary
3333
// `stageFunc` in a goroutine.
3434
type goStage struct {
35-
name string
36-
f StageFunc
37-
done chan struct{}
38-
err error
35+
name string
36+
f StageFunc
37+
done chan struct{}
38+
err error
39+
panicHandler StagePanicHandler
3940
}
4041

4142
var (
42-
_ Stage = (*goStage)(nil)
43+
_ Stage = (*goStage)(nil)
44+
_ StagePanicHandlerAware = (*goStage)(nil)
4345
)
4446

4547
func (s *goStage) Name() string {
4648
return s.name
4749
}
4850

51+
func (s *goStage) SetPanicHandler(ph StagePanicHandler) {
52+
s.panicHandler = ph
53+
}
54+
4955
func (s *goStage) Preferences() StagePreferences {
5056
return StagePreferences{
5157
StdinPreference: IOPreferenceUndefined,
@@ -67,21 +73,24 @@ func (s *goStage) Start(
6773
}
6874

6975
go func() {
70-
s.err = s.f(ctx, env, r, w)
71-
72-
if stdout != nil {
73-
if err := stdout.Close(); err != nil && s.err == nil {
74-
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
76+
defer close(s.done)
77+
defer func() {
78+
if stdin != nil {
79+
if err := stdin.Close(); err != nil && s.err == nil {
80+
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
81+
}
7582
}
76-
}
77-
78-
if stdin != nil {
79-
if err := stdin.Close(); err != nil && s.err == nil {
80-
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
83+
}()
84+
defer func() {
85+
if stdout != nil {
86+
if err := stdout.Close(); err != nil && s.err == nil {
87+
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
88+
}
8189
}
82-
}
90+
}()
91+
defer s.recoverPanic()
8392

84-
close(s.done)
93+
s.err = s.f(ctx, env, r, w)
8594
}()
8695

8796
return nil
@@ -91,3 +100,13 @@ func (s *goStage) Wait() error {
91100
<-s.done
92101
return s.err
93102
}
103+
104+
func (s *goStage) recoverPanic() {
105+
if s.panicHandler == nil {
106+
return
107+
}
108+
109+
if p := recover(); p != nil {
110+
s.err = s.panicHandler(p)
111+
}
112+
}

pipe/pipeline.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ type Pipeline struct {
6565
started uint32
6666

6767
eventHandler func(e *Event)
68+
panicHandler StagePanicHandler
6869
}
6970

7071
var emptyEventHandler = func(e *Event) {}
@@ -216,6 +217,15 @@ func WithEventHandler(handler func(e *Event)) Option {
216217
}
217218
}
218219

220+
// WithStagePanicHandler sets a panic handler for the stages within a pipeline.
221+
// When a pipeline stage panics, the provided handler will be invoked, allowing
222+
// the client to handle the panic in whatever way they see fit.
223+
func WithStagePanicHandler(ph StagePanicHandler) Option {
224+
return func(p *Pipeline) {
225+
p.panicHandler = ph
226+
}
227+
}
228+
219229
func (p *Pipeline) hasStarted() bool {
220230
return atomic.LoadUint32(&p.started) != 0
221231
}
@@ -315,6 +325,10 @@ func (p *Pipeline) Start(ctx context.Context) error {
315325
ss := &stageStarters[i]
316326
nextSS := &stageStarters[i+1]
317327

328+
if phs, ok := s.(StagePanicHandlerAware); ok && p.panicHandler != nil {
329+
phs.SetPanicHandler(p.panicHandler)
330+
}
331+
318332
// We need to generate a pipe pair for this stage to use
319333
// to communicate with its successor:
320334
if ss.prefs.StdoutPreference == IOPreferenceFile ||
@@ -343,6 +357,10 @@ func (p *Pipeline) Start(ctx context.Context) error {
343357
s := p.stages[i]
344358
ss := &stageStarters[i]
345359

360+
if phs, ok := s.(StagePanicHandlerAware); ok && p.panicHandler != nil {
361+
phs.SetPanicHandler(p.panicHandler)
362+
}
363+
346364
if err := s.Start(ctx, p.env, ss.stdin, ss.stdout); err != nil {
347365
return abort(i, err)
348366
}

pipe/pipeline_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,27 @@ func TestFunction(t *testing.T) {
457457
assert.NoError(t, err)
458458
assert.EqualValues(t, "goodbye, cruel world", out)
459459
})
460+
461+
t.Run("panic with handler", func(t *testing.T) {
462+
p := pipe.New(
463+
pipe.WithStagePanicHandler(func(p any) error {
464+
return fmt.Errorf("panic handled: %v", p)
465+
}),
466+
)
467+
p.Add(
468+
pipe.Print("hello world"),
469+
pipe.Function(
470+
"farewell",
471+
func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error {
472+
panic("this is a panic")
473+
},
474+
),
475+
)
476+
477+
out, err := p.Output(ctx)
478+
assert.ErrorContains(t, err, "panic handled")
479+
assert.Empty(t, out)
480+
})
460481
}
461482

462483
func TestPipelineWithFunction(t *testing.T) {

0 commit comments

Comments
 (0)