Skip to content

Commit 9748ec3

Browse files
update TODO and black refactor
1 parent d36709a commit 9748ec3

2 files changed

Lines changed: 24 additions & 19 deletions

File tree

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -574,10 +574,10 @@ def get_action(self, obs: Any) -> float:
574574
use_som=False,
575575
use_tabs=False,
576576
),
577-
summarizer=Summarizer(do_summary=True), # do not summarize in OSWorld
577+
summarizer=Summarizer(do_summary=True),
578578
general_hints=GeneralHints(use_hints=False),
579579
task_hint=TaskHint(use_task_hint=False),
580-
keep_last_n_obs=None, # keep only the last observation in the discussion
580+
keep_last_n_obs=None,
581581
multiaction=False, # whether to use multi-action or not
582582
action_subsets=("coord",), # or "bid"
583583
),
@@ -604,7 +604,5 @@ def get_action(self, obs: Any) -> float:
604604
multiaction=False, # whether to use multi-action or not
605605
action_subsets=("coord",),
606606
),
607-
action_set=OSWorldActionSet(
608-
"computer_13"
609-
), # or "pyautogui" #TODO: agent config should only be some primitive types.
607+
action_set=OSWorldActionSet("computer_13"),
610608
)

src/agentlab/benchmarks/osworld.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
logger = logging.getLogger(__name__)
3131

32-
# TODO: Extract X_Max and Y_MAX from screen size
3332
COMPUTER_13_ACTIONS_OAI_CHATCOMPLETION_TOOLS = [
3433
{
3534
"type": "function",
@@ -549,10 +548,10 @@ def close(self):
549548

550549
def format_chat_completion_tools_to_response_api(tools: list[dict]) -> list[dict]:
551550
"""Convert tools from OpenAI Chat Completion format to Responses API format.
552-
551+
553552
Args:
554553
tools: List of tools in Chat Completion format with nested function object
555-
554+
556555
Returns:
557556
List of tools in Responses API format with flattened structure
558557
"""
@@ -565,15 +564,16 @@ def format_chat_completion_tools_to_response_api(tools: list[dict]) -> list[dict
565564
"description": function_def["description"],
566565
"parameters": function_def["parameters"],
567566
}
568-
567+
569568
# Handle the strict field if present
570569
if "strict" in function_def:
571570
formatted_tool["strict"] = function_def["strict"]
572-
571+
573572
formatted_tools.append(formatted_tool)
574573

575574
return formatted_tools
576575

576+
577577
@dataclass
578578
class OSWorldActionSet(AbstractActionSet, DataClassJsonMixin):
579579
# TODO: Define and use agentlab AbstractActionSet
@@ -598,11 +598,23 @@ def to_python_code(self, action) -> str:
598598

599599
def to_tool_description(self, api="openai"):
600600
"""Convert the action set to a tool description for Tool-Use LLMs.
601+
601602
The default for openai is openai Response API tools format.
603+
604+
Args:
605+
api (str): The API format to use. Defaults to "openai".
606+
607+
Returns:
608+
list[dict]: List of tool descriptions in the specified API format.
609+
610+
Raises:
611+
ValueError: If an unsupported action space is specified.
602612
"""
603-
# TODO: Rename bgym AbstractActionSet to_tool_descriptor method as to_tool_description for consistency.
613+
# TODO: Rename bgym AbstractActionSet 'to_tool_descriptor' method as 'to_tool_description' for consistency.
604614
if self.action_space == "computer_13":
605-
tools = format_chat_completion_tools_to_response_api(COMPUTER_13_ACTIONS_OAI_CHATCOMPLETION_TOOLS)
615+
tools = format_chat_completion_tools_to_response_api(
616+
COMPUTER_13_ACTIONS_OAI_CHATCOMPLETION_TOOLS
617+
)
606618
else:
607619
raise ValueError(
608620
"Only 'computer_13' action space is currently supported for tool description."
@@ -642,7 +654,7 @@ class OsworldEnvArgs(AbstractEnvArgs):
642654
task: dict[str, Any]
643655
task_seed: int = 0
644656
task_name: str | None = None
645-
path_to_vm: str | None = None # path to .vmx file
657+
path_to_vm: str | None = None # path to .vmx file
646658
provider_name: str = "docker" # path to .vmx file
647659
region: str = "us-east-1" # AWS specific, does not apply to all providers
648660
snapshot_name: str = "init_state" # snapshot name to revert to
@@ -694,9 +706,7 @@ def model_post_init(self, __context: Any) -> None:
694706
self.env_args_list = []
695707
if not self.env_args:
696708
self.env_args = OsworldEnvArgs(task={})
697-
self.high_level_action_set_args = OSWorldActionSet(
698-
action_space=self.env_args.action_space
699-
)
709+
self.high_level_action_set_args = OSWorldActionSet(action_space=self.env_args.action_space)
700710
with open(os.path.join(self.test_set_path, self.test_set_name)) as f:
701711
tasks = json.load(f)
702712
if self.domain != "all":
@@ -717,9 +727,6 @@ def model_post_init(self, __context: Any) -> None:
717727

718728
def fix_settings_file_path_in_config(self, task: dict) -> dict:
719729
"""Fix the settings file path in the task configuration.
720-
721-
#TODO: We can create our own tiny_osworld.json with correct paths (OSWorld prefixed) to settings files. Meanwhile use this function to fix the paths.
722-
723730
Args:
724731
task (dict): Task configuration dictionary.
725732

0 commit comments

Comments
 (0)