Skip to content

Commit 123dfde

Browse files
committed
Allow configuring a pipeline with a panic handler.
Clients set the panic handler at the pipeline level. The pipeline is then responsible for passing the handler to stages that support the PanicHandlerAware interface.
1 parent 3510c5f commit 123dfde

3 files changed

Lines changed: 67 additions & 20 deletions

File tree

pipe/panic.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ type PanicHandlerAware interface {
66
SetPanicHandler(PanicHandler)
77
}
88

9+
// PanicHandler is a function that handles panics in the pipeline and its stages.
910
type PanicHandler func(p any) error

pipe/pipeline.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ type Pipeline struct {
6464
started uint32
6565

6666
eventHandler func(e *Event)
67+
panicHandler PanicHandler
6768
}
6869

6970
var emptyEventHandler = func(e *Event) {}
@@ -179,6 +180,19 @@ func WithEventHandler(handler func(e *Event)) Option {
179180
}
180181
}
181182

183+
// WithPanicHandler sets a panic handler for the pipeline, allowing clients to handle
184+
// and observe panics that occur during processing. When a process within the pipeline
185+
// panics, the provided handler will be invoked, enabling clients to capture the panic,
186+
// such as for observability purposes.
187+
//
188+
// Note: While the handler allows for additional processing of the panic (e.g., logging),
189+
// the panic will still propagate and be raised after the handler is executed.
190+
func WithPanicHandler(ph PanicHandler) Option {
191+
return func(p *Pipeline) {
192+
p.panicHandler = ph
193+
}
194+
}
195+
182196
func (p *Pipeline) hasStarted() bool {
183197
return atomic.LoadUint32(&p.started) != 0
184198
}
@@ -265,6 +279,12 @@ func (p *Pipeline) Start(ctx context.Context) error {
265279
}
266280

267281
for i, s := range p.stages {
282+
if p.panicHandler != nil {
283+
if phs, ok := s.(PanicHandlerAware); ok {
284+
phs.SetPanicHandler(p.panicHandler)
285+
}
286+
}
287+
268288
var err error
269289
stdout, err := s.Start(ctx, p.env, nextStdin)
270290
if err != nil {

pipe/pipeline_test.go

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -436,28 +436,54 @@ func TestFunction(t *testing.T) {
436436

437437
dir := t.TempDir()
438438

439-
p := pipe.New(pipe.WithDir(dir))
440-
p.Add(
441-
pipe.Print("hello world"),
442-
pipe.Function(
443-
"farewell",
444-
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
445-
buf, err := io.ReadAll(stdin)
446-
if err != nil {
439+
t.Run("successful function", func(t *testing.T) {
440+
p := pipe.New(pipe.WithDir(dir))
441+
p.Add(
442+
pipe.Print("hello world"),
443+
pipe.Function(
444+
"farewell",
445+
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
446+
buf, err := io.ReadAll(stdin)
447+
if err != nil {
448+
return err
449+
}
450+
if string(buf) != "hello world" {
451+
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
452+
}
453+
_, err = stdout.Write([]byte("goodbye, cruel world"))
447454
return err
448-
}
449-
if string(buf) != "hello world" {
450-
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
451-
}
452-
_, err = stdout.Write([]byte("goodbye, cruel world"))
453-
return err
454-
},
455-
),
456-
)
455+
},
456+
),
457+
)
457458

458-
out, err := p.Output(ctx)
459-
assert.NoError(t, err)
460-
assert.EqualValues(t, "goodbye, cruel world", out)
459+
out, err := p.Output(ctx)
460+
assert.NoError(t, err)
461+
assert.EqualValues(t, "goodbye, cruel world", out)
462+
})
463+
464+
t.Run("panic with handler", func(t *testing.T) {
465+
expectedErr := fmt.Errorf("recovered from panic: oh no!")
466+
p := pipe.New(
467+
pipe.WithDir(dir),
468+
pipe.WithPanicHandler(func(panicValue any) error {
469+
assert.Equal(t, "oh no!", panicValue)
470+
return expectedErr
471+
}),
472+
)
473+
p.Add(
474+
pipe.Print("hello world"),
475+
pipe.Function(
476+
"farewell",
477+
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
478+
panic("oh no!")
479+
},
480+
),
481+
)
482+
483+
out, err := p.Output(ctx)
484+
assert.ErrorIs(t, err, expectedErr)
485+
assert.Empty(t, out)
486+
})
461487
}
462488

463489
func TestPipelineWithFunction(t *testing.T) {

0 commit comments

Comments
 (0)