Skip to content

Commit b1ae20e

Browse files
committed
Add CoGames environment support
Adds minimal integration for CoGames (cogs-v-clips) environments:
- environment.py: 31-line wrapper that strips package prefixes and calls the cogames API
- torch.py: Policy and Recurrent classes
- cogames.ini: config for the machina_1.open_world mission
- pyproject.toml: adds a cogames extra with the mettagrid dependencies
- setup.py: relaxes the gymnasium/pettingzoo version constraints (>= instead of ==)
1 parent 7a99b3b commit b1ae20e

6 files changed

Lines changed: 99 additions & 2 deletions

File tree

pufferlib/config/cogames.ini

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[base]
2+
package = cogames
3+
env_name = cogames.cogs-v-clips.machina_1.open_world
4+
policy_name = Policy
5+
rnn_name = Recurrent
6+
7+
[vec]
8+
num_envs = 64
9+
num_workers = 16
10+
batch_size = auto
11+
zero_copy = True
12+
13+
[env]
14+
render_mode = none
15+
variants = heart_chorus inventory_heart_tune
16+
17+
[train]
18+
total_timesteps = 50_000_000
19+
batch_size = auto
20+
minibatch_size = 1024
21+
bptt_horizon = 64
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""CoGames integration package."""
2+
3+
from .environment import env_creator, make
4+
5+
try:
6+
import torch
7+
from .torch import Policy, Recurrent
8+
except ImportError:
9+
pass
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""CoGames wrapper for PufferLib."""
2+
3+
import functools
4+
from cogames.cli.mission import get_mission
5+
from mettagrid import PufferMettaGridEnv
6+
from mettagrid.envs.stats_tracker import StatsTracker
7+
from mettagrid.simulator import Simulator
8+
from mettagrid.util.stats_writer import NoopStatsWriter
9+
10+
11+
def env_creator(name="cogames.cogs-v-clips"):
    """Return a factory that calls ``make`` with ``name`` already bound.

    Args:
        name: Environment name forwarded to ``make`` as its ``name`` keyword.

    Returns:
        A ``functools.partial`` wrapping ``make``; any remaining keyword
        arguments can be supplied when the factory is called.
    """
    factory = functools.partial(make, name=name)
    return factory
13+
14+
15+
def make(
    name="cogames.cogs-v-clips.machina_1.open_world",
    variants=None,
    cogs=None,
    render_mode="auto",
    seed=None,
    buf=None,
):
    """Create a ``PufferMettaGridEnv`` for a CoGames mission.

    Args:
        name: Dotted environment name. Leading ``cogames`` and
            ``cogs-v-clips`` / ``cogs_v_clips`` segments are stripped; what
            remains is the mission name. If nothing remains, the mission
            defaults to ``training_facility.harvest``.
        variants: Variant names forwarded to ``get_mission`` as
            ``variants_arg``. A single whitespace-separated string (the form
            an INI value such as ``heart_chorus inventory_heart_tune`` takes)
            is split into a list of names first.
            NOTE(review): assumes ``get_mission`` expects a list of names --
            confirm against the cogames API.
        cogs: Forwarded unchanged to ``get_mission``.
        render_mode: ``"auto"`` maps to ``"none"``; ``"human"`` and ``"ansi"``
            map to ``"unicode"``; any other value is used as given.
        seed: Optional seed. ``None`` means "unseeded": the environment is
            constructed with seed ``0`` and not reset here. Any other value,
            including ``0``, is passed to the constructor and to ``reset``.
        buf: Forwarded unchanged to ``PufferMettaGridEnv``.

    Returns:
        The constructed environment with ``render_mode`` set.
    """
    # Strip package prefixes so only the mission path remains.
    parts = name.split(".")
    while parts and parts[0].replace("-", "_") in {"cogames", "cogs_v_clips"}:
        parts.pop(0)
    mission_name = ".".join(parts) if parts else "training_facility.harvest"

    # Config files deliver variants as one string; normalise it to a list so
    # it is not consumed character by character downstream.
    if isinstance(variants, str):
        variants = variants.split()

    _, env_cfg, _ = get_mission(mission_name, variants_arg=variants, cogs=cogs)

    if render_mode == "auto":
        render = "none"
    elif render_mode in {"human", "ansi"}:
        render = "unicode"
    else:
        render = render_mode

    simulator = Simulator()
    simulator.add_event_handler(StatsTracker(NoopStatsWriter()))

    # Compare against None rather than relying on truthiness: 0 is a valid
    # seed and must be treated like any other explicit seed.
    has_seed = seed is not None
    env = PufferMettaGridEnv(
        simulator=simulator,
        cfg=env_cfg,
        buf=buf,
        seed=seed if has_seed else 0,
    )
    env.render_mode = render
    if has_seed:
        env.reset(seed)
    return env
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Torch policies for CoGames environments."""
2+
3+
import torch
4+
import pufferlib.models
5+
import pufferlib.pytorch
6+
7+
8+
class Policy(pufferlib.models.Default):
    """PufferLib default policy with CoGames-specific observation encoding.

    Flat observations are flattened per sample, cast to float and multiplied
    by 1/255 before being passed to the encoder. Dict observations are
    converted with ``pufferlib.pytorch.nativize_tensor``, flattened per key,
    cast to float and concatenated.

    ``is_dict_obs``, ``dtype`` and ``encoder`` are not defined in this class;
    presumably they are provided by ``pufferlib.models.Default`` -- confirm.
    """

    def __init__(self, env, hidden_size: int = 256, **kwargs):
        """Build the policy.

        Args:
            env: Environment passed through to the base class.
            hidden_size: Hidden width passed through to the base class.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(env, hidden_size=hidden_size)
        # Non-persistent buffer: follows the module across devices but is not
        # written to the state dict.
        self.register_buffer("_inv_scale", torch.tensor(1.0 / 255.0), persistent=False)

    def encode_observations(self, observations, state=None):
        """Flatten observations and run them through the encoder.

        Args:
            observations: Batched observation tensor; the first dimension is
                used as the batch size.
            state: Unused here; kept for interface compatibility.

        Returns:
            The output of ``self.encoder`` on the flattened float features.
        """
        batch_size = observations.shape[0]
        if self.is_dict_obs:
            obs_map = pufferlib.pytorch.nativize_tensor(observations, self.dtype)
            # Cast each part to float before concatenating so the encoder
            # always receives a floating-point tensor, as in the flat branch.
            flattened = torch.cat(
                [v.view(batch_size, -1).float() for v in obs_map.values()],
                dim=1,
            )
        else:
            flattened = observations.view(batch_size, -1).float() * self._inv_scale
        return self.encoder(flattened)
21+
22+
23+
class Recurrent(pufferlib.models.LSTMWrapper):
    """LSTM wrapper for the CoGames policy, defaulting both sizes to 256."""

    def __init__(self, env, policy, input_size: int = 256, hidden_size: int = 256):
        """Forward all arguments unchanged to ``pufferlib.models.LSTMWrapper``."""
        lstm_sizes = {"input_size": input_size, "hidden_size": hidden_size}
        super().__init__(env, policy, **lstm_sizes)

pyproject.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,17 @@ metta = [
121121
'metta-mettagrid @ git+https://github.com/metta-ai/metta.git@main#subdirectory=mettagrid',
122122
]
123123

124+
cogames = [
125+
'gym',
126+
'gymnasium',
127+
'omegaconf',
128+
'hydra-core',
129+
'duckdb',
130+
'raylib>=5.5.0',
131+
'mettagrid @ git+https://github.com/metta-ai/metta.git@main#subdirectory=packages/mettagrid',
132+
'cogames @ git+https://github.com/metta-ai/metta.git@main#subdirectory=packages/cogames',
133+
]
134+
124135
microrts = [
125136
'gym==0.23',
126137
'gymnasium==0.29.1',

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ def run(self):
275275
'numpy<2.0',
276276
'shimmy[gym-v21]',
277277
'gym==0.23',
278-
'gymnasium==0.29.1',
279-
'pettingzoo==1.24.1',
278+
'gymnasium>=0.29.1',
279+
'pettingzoo>=1.24.1',
280280
]
281281

282282
if not NO_TRAIN:

0 commit comments

Comments
 (0)