Skip to content

Commit a455d0d

Browse files
add error analysis prediction validation script
1 parent a1f3416 commit a455d0d

1 file changed

Lines changed: 31 additions & 0 deletions

File tree

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from pathlib import Path
2+
from agentlab.analyze.inspect_results import (
3+
load_result_df,
4+
)
5+
import json
6+
7+
8+
def get_aggregate_statistics(exp_dir: Path):
    """Load the result table of an experiment directory.

    Args:
        exp_dir: Directory of the experiment whose results should be loaded.

    Returns:
        The object returned by ``load_result_df(exp_dir)`` (presumably a
        pandas DataFrame -- confirm against ``agentlab``).
    """
    # The previous version forwarded ``filter=filter``; no ``filter`` is
    # defined in this scope, so that silently passed the *builtin* ``filter``
    # function. The call now matches the one made in the script section below.
    # TODO(review): no aggregation is computed yet despite the function name;
    # the loaded table is returned so the call is at least useful to callers.
    return load_result_df(exp_dir)
11+
12+
13+
def load_success_prediction(exp_dir):
    """Read the predicted task outcome stored next to an experiment.

    Args:
        exp_dir: Experiment directory (``str`` or ``Path``) that may contain
            an ``error_analysis.json`` file.

    Returns:
        ``None`` when ``error_analysis.json`` does not exist in ``exp_dir``.
        Otherwise the value under ``["analysis"]["success"]`` as a bool:
        a JSON boolean is returned as is, a string is ``True`` only when it
        spells "true" (ignoring case and surrounding whitespace), and any
        other value yields ``False``.

    Raises:
        KeyError: If the file lacks the ``analysis`` or ``success`` key.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    analysis_path = Path(exp_dir) / "error_analysis.json"
    try:
        with open(analysis_path, "r", encoding="utf-8") as f:
            report = json.load(f)
    except FileNotFoundError:
        # No analysis was produced for this experiment.
        return None

    flag = report["analysis"]["success"]
    if isinstance(flag, bool):
        # A real JSON boolean used to be compared against the string "True"
        # and therefore always came out as False.
        return flag
    return isinstance(flag, str) and flag.strip().lower() == "true"


if __name__ == "__main__":
    path = Path(
        "/mnt/colab_public/data/ui_copilot/thibault/tmlr_exps/2024-10-23_14-17-47_5_agents_on_workarena_l1"
    )
    results = load_result_df(path).reset_index()

    # Keep only the runs whose model name mentions "anthropic". ``na=False``
    # keeps rows with a missing model name out of the mask instead of
    # producing NaN entries that ``.loc`` rejects; ``regex=False`` makes the
    # match a plain substring test.
    is_anthropic = results["agent.chat_model.model_name"].str.contains(
        "anthropic", regex=False, na=False
    )
    # Copy so that adding a column below does not write into a slice of the
    # unfiltered table.
    results = results.loc[is_anthropic].copy()

    # One entry per experiment: True/False when an analysis exists, else None.
    results["success_predictions"] = [
        load_success_prediction(exp_dir) for exp_dir in results["exp_dir"]
    ]

0 commit comments

Comments
 (0)