Skip to content

Commit 31e5bf5

Browse files
committed
added pipeline and tests
1 parent 9f531cc commit 31e5bf5

36 files changed

Lines changed: 1493 additions & 2 deletions

File tree

src/agentlab/agents/generic_agent/agent_configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@
257257
)
258258

259259
AGENT_4o_MINI = GenericAgentArgs(
260-
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"],
260+
chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-mini-2024-07-18"],
261261
flags=FLAGS_GPT_4o,
262262
)
263263

src/agentlab/analyze/error_analysis/__init__.py

Whitespace-only changes.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import json
2+
import re
3+
from dataclasses import dataclass
4+
from pathlib import Path
5+
from typing import Generator
6+
7+
from bgym import ExpResult
8+
9+
from agentlab.analyze.inspect_results import yield_all_exp_results
10+
11+
from .summarizer import ChangeSummarizer, EpisodeSummarizer
12+
13+
14+
@dataclass
class Analyzer:
    """Placeholder error analyzer.

    Calling an instance currently ignores every argument and returns the
    constant string ``"analysis"``.

    Attributes:
        prompt: Prompt text stored on the instance (not read by ``__call__``).
        llm: Optional model handle stored on the instance (not read by
            ``__call__``). Annotated so that ``@dataclass`` registers it as a
            real field; without an annotation it is only a class attribute
            and cannot be passed to the constructor.
    """

    prompt: str
    llm: object = None

    def __call__(self, *args, **kwds):
        """Return the placeholder analysis, regardless of the arguments."""
        return "analysis"
21+
22+
23+
@dataclass
class ErrorAnalysisPipeline:
    """Chain step summaries, an episode summary and an error analysis per experiment.

    Attributes:
        exp_dir: Directory handed to ``yield_all_exp_results`` to discover
            experiment results.
        filter: Optional substring; when set, only results whose ``exp_dir``
            (as a string) contains it are analyzed.
        step_summarizer: Object whose ``summarize(step, action, next_step,
            previous_summaries)`` produces one summary per step transition.
        episode_summarizer: Object whose ``summarize(exp_result,
            step_analysis)`` produces the episode-level summary.
        analyzer: Callable invoked as ``analyzer(exp_result, episode_analysis,
            step_analysis)`` to produce the final error analysis.
    """

    exp_dir: Path
    filter: str = None
    step_summarizer: ChangeSummarizer = None
    episode_summarizer: EpisodeSummarizer = None
    analyzer: Analyzer = None

    def filter_exp_results(self) -> Generator[ExpResult, None, None]:
        """Yield the experiment results under ``exp_dir`` that match ``filter``.

        A result matches when ``filter`` is ``None`` or is a substring of
        ``str(exp_result.exp_dir)``.
        """
        # TODO:(thibault) improve filtering
        exp_results = yield_all_exp_results(self.exp_dir)
        for exp_result in exp_results:
            if self.filter is None or self.filter in str(exp_result.exp_dir):
                yield exp_result

    def run_analysis(self):
        """Analyze every matching experiment and save each analysis next to it."""
        filtered_results = self.filter_exp_results()

        for exp_result in filtered_results:
            step_analysis = self.analyze_step(exp_result)
            episode_analysis = self.analyze_episode(exp_result, step_analysis)
            error_analysis = self.analyze_errors(exp_result, episode_analysis, step_analysis)
            self.save_analysis(exp_result, error_analysis)

    def analyze_step(self, exp_result: ExpResult) -> list[str]:
        """Summarize each transition between consecutive steps of the episode.

        Returns:
            One summary per pair of consecutive entries in
            ``exp_result.steps_info``, i.e. ``len(steps_info) - 1`` items
            (none when there are fewer than two steps).
        """
        step_summaries: list[str] = []
        # this assumes that there is always an extra step at the end of the episode
        # it is generally the case, but exps can sometimes fail in a weird way and not save the last step_info
        # TODO:(thibault) make some checks
        for step, next_step in zip(exp_result.steps_info[:-1], exp_result.steps_info[1:]):
            # The summarizer receives the summaries produced so far, so the
            # list has to grow inside the loop (no comprehension here).
            step_summaries.append(
                self.step_summarizer.summarize(step, step.action, next_step, step_summaries)
            )
        return step_summaries

    def analyze_episode(self, exp_result: ExpResult, step_analysis: list[str]) -> str:
        """Return the episode summarizer's summary for ``exp_result``."""
        episode_summary = self.episode_summarizer.summarize(exp_result, step_analysis)
        return episode_summary

    def analyze_errors(
        self, exp_result: ExpResult, episode_analysis: str, step_analysis: list[str]
    ) -> str:
        """Return whatever ``analyzer`` produces for this experiment."""
        error_analysis = self.analyzer(exp_result, episode_analysis, step_analysis)
        return error_analysis

    def save_analysis(self, exp_result: ExpResult, error_analysis: dict, exists_ok=True):
        """Save the analysis as ``error_analysis.json`` in the experiment directory.

        Args:
            exp_result: Experiment whose ``exp_dir`` receives the file.
            error_analysis: Value to serialize with ``json``.
                NOTE(review): annotated ``dict`` here while ``analyze_errors``
                is annotated as returning ``str`` - confirm the intended type.
            exists_ok: When true, an existing file is overwritten; when false,
                an existing file is an error.

        Raises:
            FileExistsError: If ``exists_ok`` is false and the file already exists.
            TypeError: If ``error_analysis`` is not JSON serializable; in that
                case nothing is written.
        """
        analysis_path = exp_result.exp_dir / "error_analysis.json"
        # Serialize before touching the file so that a non-serializable value
        # cannot leave a truncated or half-written JSON file behind.
        payload = json.dumps(error_analysis)
        # "x" (exclusive creation) makes the existence check and the creation
        # a single operation instead of a separate exists() test.
        mode = "w" if exists_ok else "x"
        try:
            with analysis_path.open(mode) as f:
                f.write(payload)
        except FileExistsError as err:
            raise FileExistsError(f"{analysis_path} already exists") from err

src/agentlab/analyze/error_analysis.py renamed to src/agentlab/analyze/error_analysis/summarizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
23
from bgym import StepInfo
34

45
CHANGE_SUMMARIZER_PROMPT = """
@@ -227,7 +228,7 @@ def summarize(
227228
past_obs_message = self.obs_formatter(past_obs)
228229
current_obs_message = self.obs_formatter(current_obs)
229230

230-
goal = past_obs["goal"] # Use goal object from agentlab
231+
goal = past_obs["goal"] # Use goal object from agentlab
231232
# Outsource everything to formatter
232233
plan = past_obs["plan"]
233234
if self.use_diff:
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
from bgym import ExpResult, StepInfo
5+
6+
from agentlab.analyze.error_analysis.pipeline import ErrorAnalysisPipeline
7+
8+
# Root directory handed to the pipeline under test as ``exp_dir``; resolved relative to this test file.
exp_dir = Path(__file__).parent.parent.parent / "data/error_analysis"
9+
10+
11+
class MockStepSummarizer:
    """Test double for the step summarizer: reports the action and its position."""

    def summarize(
        self, step: StepInfo, action: str, next_step: StepInfo, step_summaries: list[str]
    ) -> str:
        """Return a fixed-format sentence built from ``action`` and the number of prior summaries."""
        step_index = len(step_summaries)
        return f"Agent took action {action} at step {step_index}"
16+
17+
18+
class MockEpisodeSummarizer:
    """Test double for the episode summarizer: lists the episode's non-empty actions."""

    def summarize(self, exp_result: ExpResult, step_analysis: list[str]) -> str:
        """Return a sentence joining every truthy ``action`` found in ``exp_result.steps_info``."""
        taken_actions = [info.action for info in exp_result.steps_info if info.action]
        joined_actions = ", ".join(taken_actions)
        return f"Agent did actions {joined_actions}"
21+
22+
23+
class MockAnalyzer:
    """Test double for the analyzer: echoes the episode analysis in a small dict."""

    def __call__(
        self, exp_result: ExpResult, episode_analysis: str, step_analysis: list[str]
    ) -> dict[str, str]:
        """Return a JSON-serializable dict carrying ``episode_analysis``.

        The return annotation is a dict (it was previously declared ``str``
        although a dict has always been returned).
        """
        return {"error": "analysis", "episode": episode_analysis}
28+
29+
30+
@pytest.fixture(scope="module")
def pipeline() -> ErrorAnalysisPipeline:
    """Build the pipeline under test, wired to the mock components.

    The fixture is module-scoped, so the same instance is shared by every
    test in this module.
    """
    mock_components = dict(
        episode_summarizer=MockEpisodeSummarizer(),
        step_summarizer=MockStepSummarizer(),
        analyzer=MockAnalyzer(),
    )
    return ErrorAnalysisPipeline(exp_dir=exp_dir, filter=None, **mock_components)
39+
40+
41+
def test_yield_no_filter(pipeline: ErrorAnalysisPipeline):
    """Without a filter, every experiment in the data directory is yielded."""
    unfiltered_results = list(pipeline.filter_exp_results())
    assert len(unfiltered_results) == 4
43+
44+
45+
def test_yield_with_filter(pipeline: ErrorAnalysisPipeline):
    """With a filter set, only experiments whose directory contains it are yielded."""
    pattern = "click-dialog"
    previous_filter = pipeline.filter
    pipeline.filter = pattern
    try:
        assert len(list(pipeline.filter_exp_results())) == 2
    finally:
        # The fixture is module-scoped: restore the filter even when the
        # assertion fails so the remaining tests still see the original value.
        pipeline.filter = previous_filter
50+
51+
52+
def test_analyze_step(pipeline: ErrorAnalysisPipeline):
    """Step analysis yields one summary per transition, starting at step 0."""
    first_result = next(pipeline.filter_exp_results())
    summaries = pipeline.analyze_step(first_result)

    # One summary per pair of consecutive steps, hence one fewer than steps.
    assert len(first_result.steps_info) == len(summaries) + 1

    first_action = first_result.steps_info[0].action
    assert summaries[0] == f"Agent took action {first_action} at step 0"
58+
59+
60+
def test_analyze_episode(pipeline: ErrorAnalysisPipeline):
    """The episode summary mentions every non-empty action of the episode."""
    first_result = next(pipeline.filter_exp_results())
    summaries = pipeline.analyze_step(first_result)
    episode_summary = pipeline.analyze_episode(first_result, summaries)

    for info in first_result.steps_info:
        if not info.action:
            # Steps without an action are not expected in the summary.
            continue
        assert info.action in episode_summary
68+
69+
70+
def test_save_analysis(pipeline: ErrorAnalysisPipeline):
    """Saving an analysis creates ``error_analysis.json`` in the experiment directory."""
    exp_result = next(pipeline.filter_exp_results())
    step_analysis = pipeline.analyze_step(exp_result)
    episode_analysis = pipeline.analyze_episode(exp_result, step_analysis)
    error_analysis = pipeline.analyze_errors(exp_result, episode_analysis, step_analysis)

    analysis_path = exp_result.exp_dir / "error_analysis.json"
    # A previously aborted run may have left the file behind, which would make
    # the exists_ok=False save below fail for the wrong reason.
    analysis_path.unlink(missing_ok=True)

    try:
        pipeline.save_analysis(exp_result, error_analysis, exists_ok=False)

        assert analysis_path.exists()
    finally:
        # remove the file, even when the save or the assertion failed
        analysis_path.unlink(missing_ok=True)
82+
83+
84+
if __name__ == "__main__":
    # Run the test through pytest so the ``pipeline`` fixture is injected;
    # calling test_yield_with_filter() directly fails because its required
    # argument is missing.
    raise SystemExit(pytest.main([__file__, "-k", "test_yield_with_filter"]))

0 commit comments

Comments
 (0)