Skip to content

Commit f10615f

Browse files
committed
better task interface, support old bgym tasks in the new env
1 parent e28eb0f commit f10615f

5 files changed

Lines changed: 237 additions & 98 deletions

File tree

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,17 @@
1-
from agentlab.backends.browser.base import BrowserBackend
1+
from agentlab.backends.browser.base import AsyncBrowserBackend, BrowserBackend
22
from agentlab.backends.browser.env import BrowserEnv, BrowserEnvArgs
33
from agentlab.backends.browser.mcp import MCPBrowserBackend, MCPClient
44
from agentlab.backends.browser.mcp_playwright import MCPPlaywright
5-
from agentlab.backends.browser.playwright import AsyncPlaywright
5+
from agentlab.backends.browser.playwright import AsyncPlaywright, SyncPlaywright
66

77
__all__ = [
88
"BrowserBackend",
9+
"AsyncBrowserBackend",
910
"BrowserEnv",
1011
"BrowserEnvArgs",
1112
"MCPBrowserBackend",
1213
"MCPClient",
1314
"MCPPlaywright",
1415
"AsyncPlaywright",
16+
"SyncPlaywright",
1517
]

src/agentlab/backends/browser/base.py

Lines changed: 59 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,13 +10,15 @@
1010

1111

1212
class BrowserBackend(BaseModel, ABC):
13+
has_pw_page: bool = False
14+
1315
@abstractmethod
1416
def initialize(self) -> None:
1517
pass
1618

1719
@abstractmethod
18-
def run_js(self, js: str):
19-
pass
20+
def evaluate_js(self, js: str) -> str | dict | list:
21+
return ""
2022

2123
@abstractmethod
2224
def goto(self, url: str) -> str:
@@ -27,7 +29,7 @@ def page_html(self) -> str:
2729
pass
2830

2931
@abstractmethod
30-
def page_screenshot(self) -> Image:
32+
def page_screenshot(self) -> Image.Image:
3133
pass
3234

3335
@abstractmethod
@@ -39,9 +41,62 @@ def step(self, action: ToolCall) -> dict:
3941
pass
4042

4143
@abstractmethod
42-
def actions(self) -> tuple[ToolSpec]:
44+
def actions(self) -> list[ToolSpec]:
4345
pass
4446

4547
@abstractmethod
4648
def close(self) -> None:
4749
pass
50+
51+
@property
52+
def page(self):
53+
raise NotImplementedError("Direct access to the playwright page is not supported.")
54+
55+
56+
class AsyncBrowserBackend(BaseModel):
57+
"""Abstract base class for async browser backends."""
58+
59+
has_pw_page: bool = False
60+
61+
class Config:
62+
arbitrary_types_allowed = True
63+
64+
@abstractmethod
65+
async def initialize(self) -> None:
66+
pass
67+
68+
@abstractmethod
69+
async def evaluate_js(self, js: str) -> str | dict | list:
70+
pass
71+
72+
@abstractmethod
73+
async def goto(self, url: str) -> None:
74+
pass
75+
76+
@abstractmethod
77+
async def page_html(self) -> str:
78+
pass
79+
80+
@abstractmethod
81+
async def page_screenshot(self) -> Image.Image:
82+
pass
83+
84+
@abstractmethod
85+
async def page_axtree(self) -> str:
86+
pass
87+
88+
@abstractmethod
89+
async def step(self, action: ToolCall) -> dict:
90+
pass
91+
92+
@abstractmethod
93+
def actions(self) -> list[ToolSpec]:
94+
pass
95+
96+
@abstractmethod
97+
async def close(self) -> None:
98+
pass
99+
100+
@property
101+
def page(self):
102+
raise NotImplementedError("Direct access to the playwright page is not supported.")

src/agentlab/backends/browser/env.py

Lines changed: 66 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass
44
from pathlib import Path
55

6+
from browsergym.core.task import AbstractBrowserTask
7+
68
from agentlab.actions import ToolCall, ToolsActionSet, ToolSpec
79
from agentlab.backends.browser.base import BrowserBackend
810
from agentlab.benchmarks.abstract_env import AbstractEnv, AbstractEnvArgs
@@ -15,30 +17,42 @@ def final_step():
1517
"""
1618
Finish the task execution.
1719
"""
18-
pass
20+
return {
21+
"pruned_html": "Task finished",
22+
"axtree_txt": "",
23+
"last_action_error": "",
24+
"focused_element_bid": "none",
25+
}
1926

2027

2128
class BrowserEnv(AbstractEnv):
2229
def __init__(
23-
self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, seed: int = 0
30+
self, task_name: str, task: AbstractWebTask | AbstractBrowserTask, backend: BrowserBackend, seed: int = 0
2431
):
2532
self.task_name = task_name
2633
self.task = task
2734
self.seed = seed
2835
self._turns = 0
29-
self.max_turns = task.max_turns
3036
self.backend = backend
3137
self.backend.initialize()
3238
self.goal = ""
39+
if isinstance(self.task, AbstractBrowserTask) and not self.backend.has_pw_page:
40+
raise ValueError(
41+
"Legacy task requires a backend with direct playwright page access."
42+
)
3343

3444
def reset(self, seed: int):
3545
self.seed = seed
36-
logger.info(f"Open task URL: {self.task.url}")
37-
self.backend.goto(self.task.url)
38-
setup_js = self.task.get_setup_js()
39-
if setup_js:
40-
self.goal = self.task.parse_setup_result(self.backend.run_js(setup_js))
41-
logger.info(f"Task goal: {self.goal}")
46+
if isinstance(self.task, AbstractBrowserTask):
47+
self.goal, task_info = self.task.setup(page=self.backend.page)
48+
obs = self._get_obs()
49+
else:
50+
self.goal, task_info = self.task.setup(backend=self.backend)
51+
obs = self._get_obs()
52+
obs = self.task.obs_postprocess(obs)
53+
return obs, task_info
54+
55+
def _get_obs(self) -> dict:
4256
html = self.backend.page_html()
4357
screenshot = self.backend.page_screenshot()
4458
axtree = self.backend.page_axtree()
@@ -50,80 +64,72 @@ def reset(self, seed: int):
5064
"last_action_error": "",
5165
"focused_element_bid": "none",
5266
}
53-
obs = self.task.obs_postprocess(obs)
54-
return obs, {}
67+
return obs
5568

5669
def step(self, action: ToolCall | str) -> tuple[dict, float, bool, bool, dict]:
5770
if isinstance(action, str):
5871
action = ToolsActionSet.parse_action(action)
5972
logger.info(f"BrowserEnv.step() called with action {action}")
6073

6174
action_exec_start = time.time()
62-
finished = action.name == "final_step"
63-
if finished:
64-
observation = {
65-
"goal_object": [{"type": "text", "text": self.goal}],
66-
"pruned_html": "Task finished",
67-
"axtree_txt": "",
68-
"last_action_error": "",
69-
"focused_element_bid": "none",
70-
}
75+
done = action.name == "final_step"
76+
if done:
77+
observation = final_step()
7178
else:
72-
observation = self._step(action)
73-
observation = self.task.obs_postprocess(observation)
74-
79+
observation = self.backend.step(action)
7580
action_exec_stop = time.time()
7681
self._turns += 1
77-
truncated = self._turns >= self.max_turns
82+
if isinstance(self.task, AbstractWebTask):
83+
truncated = self._turns >= self.task.max_turns
84+
else:
85+
truncated = False
7886

79-
if self.task.validate_per_step or finished or truncated:
80-
reward, other = self.validate_task(action, observation)
81-
if other.get("done", False):
82-
finished = True
87+
observation = self.obs_postprocess(observation)
88+
89+
if isinstance(self.task, AbstractBrowserTask):
90+
reward, done, _, info = self.task.validate(page=self.backend.page, chat_messages=[])
91+
elif self.task.validate_per_step or done or truncated:
92+
reward, info = self.task.validate()
93+
if info.get("done", False):
94+
done = True
8395
else:
8496
reward = 0.0
85-
other = {}
97+
info = {}
8698

8799
env_info = {
100+
**info,
88101
"action_exec_start": action_exec_start,
89102
"action_exec_stop": action_exec_stop,
90-
"action_exec_timeout": 0.0,
91-
} | other
103+
"action_exec_timeout": 0.0
104+
}
92105
logger.info(f"Action result in observation: {observation}")
93-
return observation, reward, finished, truncated, env_info
94-
95-
def _step(self, action: ToolCall) -> dict:
96-
obs_dict = self.backend.step(action)
97-
if "goal_object" not in obs_dict:
98-
obs_dict["goal_object"] = [{"type": "text", "text": self.goal}]
99-
if "last_action_error" not in obs_dict:
100-
obs_dict["last_action_error"] = ""
101-
if "focused_element_bid" not in obs_dict:
102-
obs_dict["focused_element_bid"] = "none"
103-
return obs_dict
104-
105-
def validate_task(self, action: ToolCall, observation: dict) -> tuple[float, dict]:
106-
validate_js = self.task.get_step_validate_js()
107-
validate_result = self.backend.run_js(validate_js)
108-
reward, other = self.task.parse_validation_result(validate_result)
109-
return reward, other
106+
return observation, reward, done, truncated, env_info
107+
108+
def obs_postprocess(self, obs: dict) -> dict:
109+
if "goal_object" not in obs:
110+
obs["goal_object"] = [{"type": "text", "text": self.goal}]
111+
if "last_action_error" not in obs:
112+
obs["last_action_error"] = ""
113+
if "focused_element_bid" not in obs:
114+
obs["focused_element_bid"] = "none"
115+
if isinstance(self.task, AbstractWebTask):
116+
obs = self.task.obs_postprocess(obs)
117+
return obs
110118

111119
def close(self):
112-
teardown_js = self.task.get_teardown_js()
113-
if teardown_js:
114-
js_result_str = self.backend.run_js(teardown_js)
115-
logger.info(f"Task teardown result: {js_result_str}")
116-
self.backend.close()
120+
self.task.teardown()
117121

118122
def actions(self) -> list[ToolSpec]:
119123
all_actions = self.backend.actions()
120-
filtered_actions = self.task.filter_actions(all_actions)
121-
logger.info(
122-
f"Filtered {len(filtered_actions)} actions out of {len(all_actions)} for task {self.task.dataset}"
123-
)
124+
if isinstance(self.task, AbstractWebTask):
125+
filtered_actions = self.task.filter_actions(all_actions)
126+
logger.info(
127+
f"Filtered {len(filtered_actions)} actions out of {len(all_actions)} for dataset {self.task.dataset}"
128+
)
129+
else:
130+
filtered_actions = all_actions
124131
final_step_action = ToolSpec.from_function(final_step)
125-
filtered_actions.append(final_step_action)
126-
return filtered_actions
132+
return filtered_actions + [final_step_action]
127133

128134

129135
@dataclass
@@ -135,12 +141,11 @@ class BrowserEnvArgs(AbstractEnvArgs):
135141

136142
def __init__(
137143
self,
138-
task_name: str,
139144
task: AbstractWebTask,
140145
backend_cls: type[BrowserBackend],
141146
task_seed: int = 0,
142147
):
143-
self.task_name = task_name
148+
self.task_name = f"{task.dataset}.{task.task_id}"
144149
self.task = task
145150
self.task_seed = task_seed
146151
self.backend_cls = backend_cls

0 commit comments

Comments
 (0)