
Commit cde72b2

switching _agents_on_benchmark to Study method for flexibility

1 parent 67f186d commit cde72b2

1 file changed: src/agentlab/experiments/study.py
Lines changed: 82 additions & 81 deletions
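The practical effect of the change (shown in the diff below) is that experiment generation becomes overridable: make_exp_args_list now dispatches through self.agents_on_benchmark(...) instead of the module-level _agents_on_benchmark(...). A minimal sketch of the flexibility this enables, assuming the Study API in this file; SampledStudy and max_exps are hypothetical names, not part of this commit:

import random

from agentlab.experiments.study import Study


class SampledStudy(Study):
    """Hypothetical Study subclass that runs a random subset of experiments."""

    def agents_on_benchmark(self, agents, benchmark, **kwargs):
        # Build the full agent-by-task list via the parent implementation.
        exp_args_list = super().agents_on_benchmark(agents, benchmark, **kwargs)
        # Keep a random sample. Note: if the benchmark has task dependencies,
        # sampling can drop experiments that others depend on; pass
        # ignore_dependencies=True in that case.
        max_exps = 10  # invented knob, for illustration only
        return random.sample(exp_args_list, min(max_exps, len(exp_args_list)))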
@@ -1,12 +1,14 @@
-from concurrent.futures import ProcessPoolExecutor
 import gzip
 import logging
 import os
 import pickle
+import random
 import uuid
 from abc import ABC, abstractmethod
+from concurrent.futures import ProcessPoolExecutor
 from dataclasses import dataclass
 from datetime import datetime
+from multiprocessing import Manager, Pool, Queue
 from pathlib import Path

 import bgym
@@ -17,10 +19,9 @@
 from agentlab.analyze import inspect_results
 from agentlab.experiments import reproducibility_util as repro
 from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
-from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
+from agentlab.experiments.launch_exp import (find_incomplete, non_dummy_count,
+                                             run_experiments)
 from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
-from multiprocessing import Pool, Manager, Queue
-import random

 logger = logging.getLogger(__name__)

@@ -238,7 +239,7 @@ def __post_init__(self):

     def make_exp_args_list(self):
         """Generate the exp_args_list from the agent_args and the benchmark."""
-        self.exp_args_list = _agents_on_benchmark(
+        self.exp_args_list = self.agents_on_benchmark(
             self.agent_args,
             self.benchmark,
             logging_level=self.logging_level,
@@ -424,6 +425,82 @@ def load(dir: Path) -> "Study":
     def load_most_recent(root_dir: Path = None, contains=None) -> "Study":
         return Study.load(get_most_recent_study(root_dir, contains=contains))

+    def agents_on_benchmark(
+        self,
+        agents: list[AgentArgs] | AgentArgs,
+        benchmark: bgym.Benchmark,
+        demo_mode=False,
+        logging_level: int = logging.INFO,
+        logging_level_stdout: int = logging.INFO,
+        ignore_dependencies=False,
+    ):
+        """Run one or multiple agents on a benchmark.
+
+        Args:
+            agents: list[AgentArgs] | AgentArgs
+                The agent configuration(s) to run.
+            benchmark: bgym.Benchmark
+                The benchmark to run the agents on.
+            demo_mode: bool
+                If True, the experiments will be run in demo mode.
+            logging_level: int
+                The logging level for individual jobs.
+            logging_level_stdout: int
+                The logging level for stdout.
+            ignore_dependencies: bool
+                If True, dependencies are ignored and all experiments can be run in parallel.
+
+        Returns:
+            list[ExpArgs]: The list of experiments to run.
+
+        Raises:
+            ValueError: If multiple agents are run on a benchmark that requires manual reset.
+        """
+
+        if not isinstance(agents, (list, tuple)):
+            agents = [agents]
+
+        if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
+            if len(agents) > 1:
+                raise ValueError(
+                    f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
+                )
+
+        for agent in agents:
+            agent.set_benchmark(benchmark, demo_mode)  # the agent can adapt (lightly?) to the benchmark
+
+        env_args_list = benchmark.env_args_list
+        if demo_mode:
+            set_demo_mode(env_args_list)
+
+        exp_args_list = []
+
+        for agent in agents:
+            for env_args in env_args_list:
+                exp_args = ExpArgs(
+                    agent_args=agent,
+                    env_args=env_args,
+                    logging_level=logging_level,
+                    logging_level_stdout=logging_level_stdout,
+                )
+                exp_args_list.append(exp_args)
+
+        for i, exp_args in enumerate(exp_args_list):
+            exp_args.order = i
+
+        # not required with ray, but keeping it around in case we need it for visualwebarena on joblib
+        # _flag_sequential_exp(exp_args_list, benchmark)
+
+        if not ignore_dependencies:
+            # populate the depends_on field based on the task dependencies in the benchmark
+            exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks())
+        else:
+            logger.warning(
+                f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results."
+            )
+
+        return exp_args_list
+


 def _make_study_name(agent_names, benchmark_names, suffix=None):
     """Make a study name from the agent and benchmark names."""
@@ -634,82 +711,6 @@ def set_demo_mode(env_args_list: list[EnvArgs]):
         env_args.slow_mo = 1000


-def _agents_on_benchmark(
-    agents: list[AgentArgs] | AgentArgs,
-    benchmark: bgym.Benchmark,
-    demo_mode=False,
-    logging_level: int = logging.INFO,
-    logging_level_stdout: int = logging.INFO,
-    ignore_dependencies=False,
-):
-    """Run one or multiple agents on a benchmark.
-
-    Args:
-        agents: list[AgentArgs] | AgentArgs
-            The agent configuration(s) to run.
-        benchmark: bgym.Benchmark
-            The benchmark to run the agents on.
-        demo_mode: bool
-            If True, the experiments will be run in demo mode.
-        logging_level: int
-            The logging level for individual jobs.
-        logging_level_stdout: int
-            The logging level for stdout.
-        ignore_dependencies: bool
-            If True, dependencies are ignored and all experiments can be run in parallel.
-
-    Returns:
-        list[ExpArgs]: The list of experiments to run.
-
-    Raises:
-        ValueError: If multiple agents are run on a benchmark that requires manual reset.
-    """
-
-    if not isinstance(agents, (list, tuple)):
-        agents = [agents]
-
-    if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
-        if len(agents) > 1:
-            raise ValueError(
-                f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
-            )
-
-    for agent in agents:
-        agent.set_benchmark(benchmark, demo_mode)  # the agent can adapt (lightly?) to the benchmark
-
-    env_args_list = benchmark.env_args_list
-    if demo_mode:
-        set_demo_mode(env_args_list)
-
-    exp_args_list = []
-
-    for agent in agents:
-        for env_args in env_args_list:
-            exp_args = ExpArgs(
-                agent_args=agent,
-                env_args=env_args,
-                logging_level=logging_level,
-                logging_level_stdout=logging_level_stdout,
-            )
-            exp_args_list.append(exp_args)
-
-    for i, exp_args in enumerate(exp_args_list):
-        exp_args.order = i
-
-    # not required with ray, but keeping it around in case we need it for visualwebarena on joblib
-    # _flag_sequential_exp(exp_args_list, benchmark)
-
-    if not ignore_dependencies:
-        # populate the depends_on field based on the task dependencies in the benchmark
-        exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks())
-    else:
-        logger.warning(
-            f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results."
-        )
-
-    return exp_args_list
-
-
 # def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark):
 #     if benchmark.name.startswith("visualwebarena"):
 #         sequential_subset = benchmark.subset_from_glob("requires_reset", "True")
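
For reference, a hedged sketch of calling the new method directly on a study instance. The constructor arguments are illustrative (the Study dataclass may require additional fields), and my_agent_args / my_benchmark are placeholders for an AgentArgs config and a bgym.Benchmark:

study = Study(agent_args=[my_agent_args], benchmark=my_benchmark)
exp_args_list = study.agents_on_benchmark(
    study.agent_args,
    study.benchmark,
    ignore_dependencies=True,  # skips add_dependencies and logs a warning
)
# Each ExpArgs carries its position via the .order field set by the method.
print(f"{len(exp_args_list)} experiments ready to run")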
