Skip to content

Commit 23e9018

Browse files
znullCopilot
andcommitted
Port MemoryLimitWithObserver to new Stage interface
MemoryLimitWithObserver was added to main (PR #48) after the version-2 branch diverged. Port it to the new Stage interface and add test coverage. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 110d512 commit 23e9018

2 files changed

Lines changed: 142 additions & 0 deletions

File tree

pipe/memorylimit.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,99 @@ func killAtLimit(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc
8989
}
9090
}
9191

92+
// MemoryLimitWithObserver combines MemoryLimit and MemoryObserver in
93+
// one goroutine. It watches the memory usage of the stage, stops it
94+
// if it exceeds the given limit, and logs the peak memory usage when
95+
// the stage exits.
96+
func MemoryLimitWithObserver(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage {
97+
limitableStage, ok := stage.(LimitableStage)
98+
if !ok {
99+
eventHandler(&Event{
100+
Command: stage.Name(),
101+
Msg: "invalid pipe.MemoryLimitWithObserver usage",
102+
Err: fmt.Errorf("invalid pipe.MemoryLimitWithObserver usage"),
103+
})
104+
return stage
105+
}
106+
107+
return &memoryWatchStage{
108+
nameSuffix: " with memory limit",
109+
stage: limitableStage,
110+
watch: killAtLimitAndObserve(byteLimit, eventHandler),
111+
}
112+
}
113+
114+
func killAtLimitAndObserve(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc {
115+
return func(ctx context.Context, stage LimitableStage) {
116+
var (
117+
maxRSS uint64
118+
samples, errCount, consecutiveErrors int
119+
killed bool
120+
)
121+
122+
t := time.NewTicker(memoryPollInterval)
123+
defer t.Stop()
124+
125+
for {
126+
select {
127+
case <-ctx.Done():
128+
eventHandler(&Event{
129+
Command: stage.Name(),
130+
Msg: "peak memory usage",
131+
Context: map[string]interface{}{
132+
"max_rss_bytes": maxRSS,
133+
"samples": samples,
134+
"errors": errCount,
135+
},
136+
})
137+
return
138+
case <-t.C:
139+
if killed {
140+
continue
141+
}
142+
143+
rss, err := stage.GetRSSAnon(ctx)
144+
if err != nil {
145+
if !errors.Is(err, errProcessInfoMissing) {
146+
errCount++
147+
consecutiveErrors++
148+
if consecutiveErrors == 2 {
149+
eventHandler(&Event{
150+
Command: stage.Name(),
151+
Msg: "error getting RSS",
152+
Err: err,
153+
})
154+
}
155+
} else {
156+
consecutiveErrors = 0
157+
}
158+
continue
159+
}
160+
161+
consecutiveErrors = 0
162+
samples++
163+
if rss > maxRSS {
164+
maxRSS = rss
165+
}
166+
167+
if rss >= byteLimit {
168+
eventHandler(&Event{
169+
Command: stage.Name(),
170+
Msg: "stage exceeded allowed memory use",
171+
Err: fmt.Errorf("stage exceeded allowed memory use"),
172+
Context: map[string]interface{}{
173+
"limit": byteLimit,
174+
"used": rss,
175+
},
176+
})
177+
stage.Kill(ErrMemoryLimitExceeded)
178+
killed = true
179+
}
180+
}
181+
}
182+
}
183+
}
184+
92185
// MemoryObserver watches memory use of the stage and logs the maximum
93186
// value when the stage exits.
94187
func MemoryObserver(stage Stage, eventHandler func(e *Event)) Stage {

pipe/memorylimit_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,22 @@ func TestMemoryLimitTreeMem(t *testing.T) {
113113
require.ErrorContains(t, err, "memory limit exceeded")
114114
}
115115

116+
func TestMemoryLimitWithObserverSimple(t *testing.T) {
117+
t.Parallel()
118+
msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less"))
119+
assert.Contains(t, msg, "exceeded allowed memory")
120+
assert.Contains(t, msg, "limit=10000000")
121+
require.ErrorContains(t, err, "memory limit exceeded")
122+
}
123+
124+
func TestMemoryLimitWithObserverTreeMem(t *testing.T) {
125+
t.Parallel()
126+
msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("sh", "-c", "less; :"))
127+
assert.Contains(t, msg, "exceeded allowed memory")
128+
assert.Contains(t, msg, "limit=10000000")
129+
require.ErrorContains(t, err, "memory limit exceeded")
130+
}
131+
116132
func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) {
117133
ctx := context.Background()
118134

@@ -147,3 +163,36 @@ func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (str
147163

148164
return buf.String(), err
149165
}
166+
167+
func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) {
168+
ctx := context.Background()
169+
170+
devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0)
171+
require.NoError(t, err)
172+
173+
buf := &bytes.Buffer{}
174+
logger := log.New(buf, "testMemoryLimitWithObserver", log.Ldate|log.Ltime)
175+
176+
p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull))
177+
p.Add(
178+
pipe.Function(
179+
"write-to-less",
180+
func(ctx context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error {
181+
var bytes [1_000_000]byte
182+
for i := 0; i < mbs; i++ {
183+
_, err := stdout.Write(bytes[:])
184+
if err != nil {
185+
require.ErrorIs(t, err, syscall.EPIPE)
186+
}
187+
}
188+
return nil
189+
},
190+
),
191+
pipe.MemoryLimitWithObserver(stage, limit, LogEventHandler(logger)),
192+
)
193+
require.NoError(t, p.Start(ctx))
194+
195+
err = p.Wait()
196+
197+
return buf.String(), err
198+
}

0 commit comments

Comments
 (0)