@@ -39,16 +39,18 @@ def summarize(self, obs: StepInfo, next_obs: StepInfo, past_summaries: list[str]
3939 if self .use_diff :
4040 next_obs_message = _diff (obs_message , next_obs_message )
4141
42- return self .parse (self .llm (
43- self .make_prompt (
44- obs_message ,
45- action ,
46- next_obs_message ,
47- past_summaries ,
48- goal ,
49- obs .obs .get ("plan" , "No plan available" ),
50- )
51- )['content' ])
42+ return self .parse (
43+ self .llm (
44+ self .make_prompt (
45+ obs_message ,
46+ action ,
47+ next_obs_message ,
48+ past_summaries ,
49+ goal ,
50+ obs .obs .get ("plan" , "No plan available" ),
51+ )
52+ )["content" ]
53+ )
5254
5355 def make_prompt (
5456 self , past_obs_message , action , current_obs_message , past_summaries , goal , plan
@@ -64,7 +66,9 @@ def make_prompt(
6466 )
6567
6668 def parse (self , raw_output : str ) -> dict :
67- parsed_result = parse_html_tags (raw_output , keys = ["changeSummary" , "actionAssessment" , "explanation" , "suggestion" ])[0 ]
69+ parsed_result = parse_html_tags (
70+ raw_output , keys = ["changeSummary" , "actionAssessment" , "explanation" , "suggestion" ]
71+ )[0 ]
6872 return parsed_result
6973
7074
@@ -105,15 +109,16 @@ def make_change_summaries(self, exp_result: ExpResult) -> list[str]:
105109 # it is generally the case, but exps can sometimes fail in a weird way and not save the last step_info
106110 # TODO:(thibault) make some checks or w/e
107111 for step , next_step in zip (exp_result .steps_info [:- 1 ], exp_result .steps_info [1 :]):
108- summaries .append (
109- self .change_summarizer .summarize (step , next_step , summaries )
110- )
112+ summaries .append (self .change_summarizer .summarize (step , next_step , summaries ))
111113 return summaries
112114
113115 def parse (self , raw_output : str ) -> dict :
114- parsed_result = parse_html_tags (raw_output , keys = ["explanation" , "success" , "errorCategory" ])[0 ]
116+ parsed_result = parse_html_tags (
117+ raw_output , keys = ["explanation" , "success" , "errorCategory" ]
118+ )[0 ]
115119 return parsed_result
116120
121+
117122@dataclass
118123class EpisodeErrorSummarizer (EpisodeSummarizer ):
119124
@@ -124,7 +129,7 @@ def make_prompt(self, exp_results: ExpResult, summaries: list[str]):
124129 goal = exp_results .steps_info [0 ].obs ["goal" ]
125130
126131 def format_summary (summary ):
127- res = ''
132+ res = ""
128133 for key , value in summary .items ():
129134 res += f"{ key } : { value } \n "
130135 return res
0 commit comments