77 ERROR_CLASSIFICATION_PROMPT ,
88)
99from agentlab .analyze .inspect_results import summarize
10- from agentlab .llm .llm_utils import json_parser
10+ from agentlab .llm .llm_utils import json_parser , parse_html_tags
1111
1212
1313def _diff (past_obs , current_obs ):
@@ -39,7 +39,7 @@ def summarize(self, obs: StepInfo, next_obs: StepInfo, past_summaries: list[str]
3939 if self .use_diff :
4040 next_obs_message = _diff (obs_message , next_obs_message )
4141
42- return self .llm (
42+ return self .parse ( self . llm (
4343 self .make_prompt (
4444 obs_message ,
4545 action ,
@@ -48,7 +48,7 @@ def summarize(self, obs: StepInfo, next_obs: StepInfo, past_summaries: list[str]
4848 goal ,
4949 obs .obs .get ("plan" , "No plan available" ),
5050 )
51- )
51+ )[ 'content' ])
5252
5353 def make_prompt (
5454 self , past_obs_message , action , current_obs_message , past_summaries , goal , plan
@@ -63,6 +63,10 @@ def make_prompt(
6363 action = action ,
6464 )
6565
66+ def parse (self , raw_output : str ) -> dict :
67+ parsed_result = parse_html_tags (raw_output , keys = ["changeSummary" , "actionAssessment" , "explanation" , "suggestion" ])[0 ]
68+ return parsed_result
69+
6670
6771@dataclass
6872class EpisodeAnalysis :
@@ -83,13 +87,13 @@ def make_prompt(self, exp_results: ExpResult, summaries: list[str]): ...
8387 def __call__ (self , exp_results : ExpResult ) -> EpisodeAnalysis :
8488 """Run Change Summarizer for every step in the episode or extract a pre-computed one."""
8589
86- if exp_results .steps_info [- 1 ].reward == 1 :
87- return {"analysis" : "Success" , "summaries" : {}}
90+ # if exp_results.steps_info[-1].reward == 1:
91+ # return {"analysis": "Success", "summaries": {}}
8892
8993 summaries = self .make_change_summaries (exp_results )
9094 prompt = self .make_prompt (exp_results , summaries )
9195 raw_analysis = self .llm (prompt )["content" ]
92- analysis = self .parser (raw_analysis )
96+ analysis = self .parse (raw_analysis )
9397 return {
9498 "analysis" : analysis ,
9599 "summaries" : {i : self .parser (a ) for i , a in enumerate (summaries )},
@@ -102,10 +106,13 @@ def make_change_summaries(self, exp_result: ExpResult) -> list[str]:
102106 # TODO:(thibault) make some checks or w/e
103107 for step , next_step in zip (exp_result .steps_info [:- 1 ], exp_result .steps_info [1 :]):
104108 summaries .append (
105- self .change_summarizer .summarize (step , next_step , summaries )[ "content" ]
109+ self .change_summarizer .summarize (step , next_step , summaries )
106110 )
107111 return summaries
108112
113+ def parse (self , raw_output : str ) -> dict :
114+ parsed_result = parse_html_tags (raw_output , keys = ["explanation" , "success" , "errorCategory" ])[0 ]
115+ return parsed_result
109116
110117@dataclass
111118class EpisodeErrorSummarizer (EpisodeSummarizer ):
@@ -116,7 +123,13 @@ def make_prompt(self, exp_results: ExpResult, summaries: list[str]):
116123 """TODO: Implement the prompt."""
117124 goal = exp_results .steps_info [0 ].obs ["goal" ]
118125
119- txt_summaries = "\n " .join (summaries )
126+ def format_summary (summary ):
127+ res = ''
128+ for key , value in summary .items ():
129+ res += f"{ key } : { value } \n "
130+ return res
131+
132+ txt_summaries = "\n " .join ([format_summary (summary ) for summary in summaries ])
120133
121134 thoughts = [step .agent_info .think for step in exp_results .steps_info [:- 1 ]]
122135 actions = [step .action for step in exp_results .steps_info [:- 1 ]]
0 commit comments