Skip to content

Commit 2c2a850

Browse files
add changes from private repo
1 parent 2e87d58 commit 2c2a850

4 files changed

Lines changed: 210 additions & 152 deletions

File tree

Lines changed: 80 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
__version__ = "0.3.2"
22

33
import inspect
4-
import numpy as np
4+
from logging import warning
55

6+
import numpy as np
67
from browsergym.core.registration import register_task
78

89
from .tasks.comp_building_block import CompositionalBuildingBlockTask
9-
from .tasks.dashboard import __TASKS__ as DASHBOARD_TASKS
10-
from .tasks.form import __TASKS__ as FORM_TASKS
11-
from .tasks.knowledge import __TASKS__ as KB_TASKS
12-
from .tasks.list import __TASKS__ as LIST_TASKS
13-
from .tasks.navigation import __TASKS__ as NAVIGATION_TASKS
14-
from .tasks.compositional.base import CompositionalTask
15-
from .tasks.compositional.update_task import __TASKS__ as UPDATE_TASKS
1610
from .tasks.compositional import (
11+
AGENT_CURRICULUM_L2,
12+
AGENT_CURRICULUM_L3,
1713
ALL_COMPOSITIONAL_TASKS,
1814
ALL_COMPOSITIONAL_TASKS_L2,
1915
ALL_COMPOSITIONAL_TASKS_L3,
20-
AGENT_CURRICULUM_L2,
21-
AGENT_CURRICULUM_L3,
2216
HUMAN_CURRICULUM_L2,
2317
HUMAN_CURRICULUM_L3,
2418
)
25-
from .tasks.compositional.base import HumanEvalTask
19+
from .tasks.compositional.base import CompositionalTask, HumanEvalTask
20+
from .tasks.compositional.update_task import __TASKS__ as UPDATE_TASKS
21+
from .tasks.dashboard import __TASKS__ as DASHBOARD_TASKS
22+
from .tasks.form import __TASKS__ as FORM_TASKS
23+
from .tasks.knowledge import __TASKS__ as KB_TASKS
24+
from .tasks.list import __TASKS__ as LIST_TASKS
25+
from .tasks.navigation import __TASKS__ as NAVIGATION_TASKS
2626
from .tasks.service_catalog import __TASKS__ as SERVICE_CATALOG_TASKS
2727
from .tasks.compositional.base import CompositionalTask
2828

@@ -53,8 +53,66 @@
5353
task,
5454
)
5555

56+
workarena_tasks_all = [task_class.get_task_id() for task_class in ALL_WORKARENA_TASKS]
57+
workarena_tasks_atomic = [task_class.get_task_id() for task_class in ATOMIC_TASKS]
58+
59+
TASK_CATEGORY_MAP = {
60+
"workarena.servicenow.all-menu": "menu",
61+
"workarena.servicenow.create-change-request": "form",
62+
"workarena.servicenow.create-hardware-asset": "form",
63+
"workarena.servicenow.create-incident": "form",
64+
"workarena.servicenow.create-problem": "form",
65+
"workarena.servicenow.create-user": "form",
66+
"workarena.servicenow.filter-asset-list": "list-filter",
67+
"workarena.servicenow.filter-change-request-list": "list-filter",
68+
"workarena.servicenow.filter-hardware-list": "list-filter",
69+
"workarena.servicenow.filter-incident-list": "list-filter",
70+
"workarena.servicenow.filter-service-catalog-item-list": "list-filter",
71+
"workarena.servicenow.filter-user-list": "list-filter",
72+
"workarena.servicenow.impersonation": "menu",
73+
"workarena.servicenow.knowledge-base-search": "knowledge",
74+
"workarena.servicenow.order-apple-mac-book-pro15": "service catalog",
75+
"workarena.servicenow.order-apple-watch": "service catalog",
76+
"workarena.servicenow.order-developer-laptop": "service catalog",
77+
"workarena.servicenow.order-development-laptop-p-c": "service catalog",
78+
"workarena.servicenow.order-ipad-mini": "service catalog",
79+
"workarena.servicenow.order-ipad-pro": "service catalog",
80+
"workarena.servicenow.order-loaner-laptop": "service catalog",
81+
"workarena.servicenow.order-sales-laptop": "service catalog",
82+
"workarena.servicenow.order-standard-laptop": "service catalog",
83+
"workarena.servicenow.sort-asset-list": "list-sort",
84+
"workarena.servicenow.sort-change-request-list": "list-sort",
85+
"workarena.servicenow.sort-hardware-list": "list-sort",
86+
"workarena.servicenow.sort-incident-list": "list-sort",
87+
"workarena.servicenow.sort-service-catalog-item-list": "list-sort",
88+
"workarena.servicenow.sort-user-list": "list-sort",
89+
"workarena.servicenow.dashboard-min-max-retrieval": "dashboard",
90+
"workarena.servicenow.dashboard-value-retrieval": "dashboard",
91+
"workarena.servicenow.report-value-retrieval": "dashboard",
92+
"workarena.servicenow.report-min-max-retrieval": "dashboard",
93+
}
94+
95+
96+
workarena_tasks_l1 = list(TASK_CATEGORY_MAP.keys())
97+
workarena_task_categories = {}
98+
for task in workarena_tasks_atomic:
99+
if task not in TASK_CATEGORY_MAP:
100+
warning(f"Atomic task {task} not found in TASK_CATEGORY_MAP")
101+
continue
102+
cat = TASK_CATEGORY_MAP[task]
103+
if cat in workarena_task_categories:
104+
workarena_task_categories[cat].append(task)
105+
else:
106+
workarena_task_categories[cat] = [task]
107+
56108

57-
def get_all_tasks_agents(filter="l2", meta_seed=42, n_seed_l1=10):
109+
def get_task_category(task_name):
110+
benchmark = task_name.split(".")[0]
111+
return benchmark, TASK_CATEGORY_MAP.get(task_name, None)
112+
113+
114+
def get_all_tasks_agents(filter="l2", meta_seed=42, n_seed_l1=10, is_agent_curriculum=True):
115+
OFFSET = 42
58116
all_task_tuples = []
59117
filter = filter.split(".")
60118
if len(filter) > 2:
@@ -79,60 +137,25 @@ def get_all_tasks_agents(filter="l2", meta_seed=42, n_seed_l1=10):
79137
else:
80138
filter_category = None
81139

82-
if level == "l2":
83-
ALL_COMPOSITIONAL_TASKS_CATEGORIES = AGENT_CURRICULUM_L2
84-
else:
85-
ALL_COMPOSITIONAL_TASKS_CATEGORIES = AGENT_CURRICULUM_L3
86-
87-
for category, items in ALL_COMPOSITIONAL_TASKS_CATEGORIES.items():
88-
if filter_category and category != filter_category:
89-
continue
90-
for curr_seed in rng.randint(0, 1000, items["num_seeds"]):
91-
random_gen = np.random.RandomState(curr_seed)
92-
for task_set, count in zip(items["buckets"], items["weights"]):
93-
tasks = random_gen.choice(task_set, count, replace=False)
94-
for task in tasks:
95-
all_task_tuples.append((task, curr_seed))
96-
97-
return all_task_tuples
98-
99-
100-
def get_all_tasks_humans(filter="l2", meta_seed=42):
101-
OFFSET = 42
102-
all_task_tuples = []
103-
filter = filter.split(".")
104-
if len(filter) > 2:
105-
raise Exception("Unsupported filter used.")
106-
if len(filter) == 1:
107-
level = filter[0]
108-
if level not in ["l1", "l2", "l3"]:
109-
raise Exception("Unsupported category of tasks.")
140+
if is_agent_curriculum:
141+
if level == "l2":
142+
ALL_COMPOSITIONAL_TASKS_CATEGORIES = AGENT_CURRICULUM_L2
110143
else:
111-
rng = np.random.RandomState(meta_seed)
112-
if level == "l1":
113-
return [(task, rng.randint(0, 1000)) for task in ATOMIC_TASKS]
114-
115-
if len(filter) == 2:
116-
level, filter_category = filter[0], filter[1]
117-
if filter_category not in list(HUMAN_CURRICULUM_L2.keys()):
118-
raise Exception("Unsupported category of tasks.")
144+
ALL_COMPOSITIONAL_TASKS_CATEGORIES = AGENT_CURRICULUM_L3
119145
else:
120-
filter_category = None
121-
122-
if level == "l2":
123-
ALL_COMPOSITIONAL_TASKS_CATEGORIES = HUMAN_CURRICULUM_L2
124-
else:
125-
ALL_COMPOSITIONAL_TASKS_CATEGORIES = HUMAN_CURRICULUM_L3
146+
if level == "l2":
147+
ALL_COMPOSITIONAL_TASKS_CATEGORIES = HUMAN_CURRICULUM_L2
148+
else:
149+
ALL_COMPOSITIONAL_TASKS_CATEGORIES = HUMAN_CURRICULUM_L3
126150

127151
for category, items in ALL_COMPOSITIONAL_TASKS_CATEGORIES.items():
128152
if filter_category and category != filter_category:
129153
continue
130-
# We will come back to this after the submission
131154
for curr_seed in rng.randint(0, 1000, items["num_seeds"]):
132155
random_gen = np.random.RandomState(curr_seed)
133156
for task_set, count in zip(items["buckets"], items["weights"]):
134157
tasks = random_gen.choice(task_set, count, replace=False)
135158
for task in tasks:
136-
all_task_tuples.append((task, curr_seed))
159+
all_task_tuples.append((task, int(curr_seed)))
137160

138-
return all_task_tuples
161+
return all_task_tuples

src/browsergym/workarena/tasks/dashboard.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import numpy as np
44
import playwright.sync_api
55
import re
6-
import tenacity
76

87
from abc import ABC, abstractmethod
8+
from tenacity import retry, stop_after_attempt, wait_fixed
99
from typing import List, Tuple
1010
from urllib import parse
1111

@@ -179,6 +179,8 @@ def _read_chart(self, page: playwright.sync_api.Page, element_id: str) -> str:
179179

180180
return type, data
181181

182+
# retry because sometimes the page is not fully loaded
183+
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
182184
def _get_chart_by_title(
183185
self, page: playwright.sync_api.Page, title: str = None
184186
) -> Tuple[str, dict]:
@@ -799,4 +801,4 @@ def setup_goal(self, page: playwright.sync_api.Page) -> Tuple[str | dict]:
799801
and issubclass(var, DashboardRetrievalTask)
800802
and not issubclass(var, CompositionalBuildingBlockTask)
801803
and var is not DashboardRetrievalTask
802-
]
804+
]

0 commit comments

Comments
 (0)