33from dataclasses import dataclass
44from pathlib import Path
55
6+ from browsergym .core .task import AbstractBrowserTask
7+
68from agentlab .actions import ToolCall , ToolsActionSet , ToolSpec
79from agentlab .backends .browser .base import BrowserBackend
810from agentlab .benchmarks .abstract_env import AbstractEnv , AbstractEnvArgs
@@ -15,30 +17,42 @@ def final_step():
1517 """
1618 Finish the task execution.
1719 """
18- pass
20+ return {
21+ "pruned_html" : "Task finished" ,
22+ "axtree_txt" : "" ,
23+ "last_action_error" : "" ,
24+ "focused_element_bid" : "none" ,
25+ }
1926
2027
2128class BrowserEnv (AbstractEnv ):
2229 def __init__ (
23- self , task_name : str , task : AbstractWebTask , backend : BrowserBackend , seed : int = 0
30+ self , task_name : str , task : AbstractWebTask | AbstractBrowserTask , backend : BrowserBackend , seed : int = 0
2431 ):
2532 self .task_name = task_name
2633 self .task = task
2734 self .seed = seed
2835 self ._turns = 0
29- self .max_turns = task .max_turns
3036 self .backend = backend
3137 self .backend .initialize ()
3238 self .goal = ""
39+ if isinstance (self .task , AbstractBrowserTask ) and not self .backend .has_pw_page :
40+ raise ValueError (
41+ "Legacy task requires a backend with direct playwright page access."
42+ )
3343
3444 def reset (self , seed : int ):
3545 self .seed = seed
36- logger .info (f"Open task URL: { self .task .url } " )
37- self .backend .goto (self .task .url )
38- setup_js = self .task .get_setup_js ()
39- if setup_js :
40- self .goal = self .task .parse_setup_result (self .backend .run_js (setup_js ))
41- logger .info (f"Task goal: { self .goal } " )
46+ if isinstance (self .task , AbstractBrowserTask ):
47+ self .goal , task_info = self .task .setup (page = self .backend .page )
48+ obs = self ._get_obs ()
49+ else :
50+ self .goal , task_info = self .task .setup (backend = self .backend )
51+ obs = self ._get_obs ()
52+ obs = self .task .obs_postprocess (obs )
53+ return obs , task_info
54+
55+ def _get_obs (self ) -> dict :
4256 html = self .backend .page_html ()
4357 screenshot = self .backend .page_screenshot ()
4458 axtree = self .backend .page_axtree ()
@@ -50,80 +64,72 @@ def reset(self, seed: int):
5064 "last_action_error" : "" ,
5165 "focused_element_bid" : "none" ,
5266 }
53- obs = self .task .obs_postprocess (obs )
54- return obs , {}
67+ return obs
5568
5669 def step (self , action : ToolCall | str ) -> tuple [dict , float , bool , bool , dict ]:
5770 if isinstance (action , str ):
5871 action = ToolsActionSet .parse_action (action )
5972 logger .info (f"BrowserEnv.step() called with action { action } " )
6073
6174 action_exec_start = time .time ()
62- finished = action .name == "final_step"
63- if finished :
64- observation = {
65- "goal_object" : [{"type" : "text" , "text" : self .goal }],
66- "pruned_html" : "Task finished" ,
67- "axtree_txt" : "" ,
68- "last_action_error" : "" ,
69- "focused_element_bid" : "none" ,
70- }
75+ done = action .name == "final_step"
76+ if done :
77+ observation = final_step ()
7178 else :
72- observation = self ._step (action )
73- observation = self .task .obs_postprocess (observation )
74-
79+ observation = self .backend .step (action )
7580 action_exec_stop = time .time ()
7681 self ._turns += 1
77- truncated = self ._turns >= self .max_turns
82+ if isinstance (self .task , AbstractWebTask ):
83+ truncated = self ._turns >= self .task .max_turns
84+ else :
85+ truncated = False
7886
79- if self .task .validate_per_step or finished or truncated :
80- reward , other = self .validate_task (action , observation )
81- if other .get ("done" , False ):
82- finished = True
87+ observation = self .obs_postprocess (observation )
88+
89+ if isinstance (self .task , AbstractBrowserTask ):
90+ reward , done , _ , info = self .task .validate (page = self .backend .page , chat_messages = [])
91+ elif self .task .validate_per_step or done or truncated :
92+ reward , info = self .task .validate ()
93+ if info .get ("done" , False ):
94+ done = True
8395 else :
8496 reward = 0.0
85- other = {}
97+ info = {}
8698
8799 env_info = {
100+ ** info ,
88101 "action_exec_start" : action_exec_start ,
89102 "action_exec_stop" : action_exec_stop ,
90- "action_exec_timeout" : 0.0 ,
91- } | other
103+ "action_exec_timeout" : 0.0
104+ }
92105 logger .info (f"Action result in observation: { observation } " )
93- return observation , reward , finished , truncated , env_info
94-
95- def _step (self , action : ToolCall ) -> dict :
96- obs_dict = self .backend .step (action )
97- if "goal_object" not in obs_dict :
98- obs_dict ["goal_object" ] = [{"type" : "text" , "text" : self .goal }]
99- if "last_action_error" not in obs_dict :
100- obs_dict ["last_action_error" ] = ""
101- if "focused_element_bid" not in obs_dict :
102- obs_dict ["focused_element_bid" ] = "none"
103- return obs_dict
104-
105- def validate_task (self , action : ToolCall , observation : dict ) -> tuple [float , dict ]:
106- validate_js = self .task .get_step_validate_js ()
107- validate_result = self .backend .run_js (validate_js )
108- reward , other = self .task .parse_validation_result (validate_result )
109- return reward , other
106+ return observation , reward , done , truncated , env_info
107+
108+ def obs_postprocess (self , obs : dict ) -> dict :
109+ if "goal_object" not in obs :
110+ obs ["goal_object" ] = [{"type" : "text" , "text" : self .goal }]
111+ if "last_action_error" not in obs :
112+ obs ["last_action_error" ] = ""
113+ if "focused_element_bid" not in obs :
114+ obs ["focused_element_bid" ] = "none"
115+ if isinstance (self .task , AbstractWebTask ):
116+ obs = self .task .obs_postprocess (obs )
117+ return obs
110118
111119 def close (self ):
112- teardown_js = self .task .get_teardown_js ()
113- if teardown_js :
114- js_result_str = self .backend .run_js (teardown_js )
115- logger .info (f"Task teardown result: { js_result_str } " )
116- self .backend .close ()
120+ self .task .teardown ()
117121
118122 def actions (self ) -> list [ToolSpec ]:
119123 all_actions = self .backend .actions ()
120- filtered_actions = self .task .filter_actions (all_actions )
121- logger .info (
122- f"Filtered { len (filtered_actions )} actions out of { len (all_actions )} for task { self .task .dataset } "
123- )
124+ if isinstance (self .task , AbstractWebTask ):
125+ filtered_actions = self .task .filter_actions (all_actions )
126+ logger .info (
127+ f"Filtered { len (filtered_actions )} actions out of { len (all_actions )} for dataset { self .task .dataset } "
128+ )
129+ else :
130+ filtered_actions = all_actions
124131 final_step_action = ToolSpec .from_function (final_step )
125- filtered_actions .append (final_step_action )
126- return filtered_actions
132+ return filtered_actions + [final_step_action ]
127133
128134
129135@dataclass
@@ -135,12 +141,11 @@ class BrowserEnvArgs(AbstractEnvArgs):
135141
136142 def __init__ (
137143 self ,
138- task_name : str ,
139144 task : AbstractWebTask ,
140145 backend_cls : type [BrowserBackend ],
141146 task_seed : int = 0 ,
142147 ):
143- self .task_name = task_name
148+ self .task_name = f" { task . dataset } . { task . task_id } "
144149 self .task = task
145150 self .task_seed = task_seed
146151 self .backend_cls = backend_cls
0 commit comments