Skip to content

Commit 8f88d9f

Browse files
authored
Merge pull request #400 from relh/richard-cogames-train
Add cogames.cogs-v-clips install/training command
2 parents 5421f8e + 21d3330 commit 8f88d9f

6 files changed

Lines changed: 95 additions & 2 deletions

File tree

pufferlib/config/cogames.ini

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[base]
2+
package = cogames
3+
env_name = cogames.cogs_v_clips.training_facility.harvest cogames.cogs_v_clips.training_facility.assemble cogames.cogs_v_clips.machina_1.open_world
4+
policy_name = Policy
5+
rnn_name = Recurrent
6+
7+
[vec]
8+
num_envs = 64
9+
num_workers = 16
10+
batch_size = auto
11+
zero_copy = True
12+
13+
[env]
14+
render_mode = none
15+
variants = heart_chorus inventory_heart_tune
16+
17+
[train]
18+
total_timesteps = 50_000_000
19+
batch_size = auto
20+
minibatch_size = 1024
21+
bptt_horizon = 64
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""CoGames integration package."""
2+
3+
from .environment import env_creator, make
4+
5+
try:
6+
import torch
7+
from .torch import Policy, Recurrent
8+
except ImportError:
9+
pass
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""CoGames wrapper for PufferLib."""
2+
3+
import functools
4+
from cogames.cli.mission import get_mission
5+
from mettagrid import PufferMettaGridEnv
6+
from mettagrid.envs.stats_tracker import StatsTracker
7+
from mettagrid.simulator import Simulator
8+
from mettagrid.util.stats_writer import NoopStatsWriter
9+
10+
11+
def env_creator(name="cogames.cogs_v_clips.machina_1.open_world"):
12+
return functools.partial(make, name=name)
13+
14+
15+
def make(name="cogames.cogs_v_clips.machina_1.open_world", variants=None, cogs=None, render_mode="auto", seed=None, buf=None):
16+
mission_name = name.removeprefix("cogames.cogs_v_clips.") if name.startswith("cogames.cogs_v_clips.") else name
17+
variants = variants.split() if isinstance(variants, str) else variants
18+
_, env_cfg, _ = get_mission(mission_name, variants_arg=variants, cogs=cogs)
19+
20+
render = "none" if render_mode == "auto" else "unicode" if render_mode in {"human", "ansi"} else render_mode
21+
simulator = Simulator()
22+
simulator.add_event_handler(StatsTracker(NoopStatsWriter()))
23+
env = PufferMettaGridEnv(simulator=simulator, cfg=env_cfg, buf=buf, seed=seed or 0)
24+
env.render_mode = render
25+
if seed:
26+
env.reset(seed)
27+
return env
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Torch policies for CoGames environments."""
2+
3+
import torch
4+
import pufferlib.models
5+
import pufferlib.pytorch
6+
7+
8+
class Policy(pufferlib.models.Default):
9+
def __init__(self, env, hidden_size: int = 256, **kwargs):
10+
super().__init__(env, hidden_size=hidden_size)
11+
self.register_buffer("_inv_scale", torch.tensor(1.0 / 255.0), persistent=False)
12+
13+
def encode_observations(self, observations, state=None):
14+
batch_size = observations.shape[0]
15+
if self.is_dict_obs:
16+
obs_map = pufferlib.pytorch.nativize_tensor(observations, self.dtype)
17+
flattened = torch.cat([v.view(batch_size, -1) for v in obs_map.values()], dim=1)
18+
else:
19+
flattened = observations.view(batch_size, -1).float() * self._inv_scale
20+
return self.encoder(flattened)
21+
22+
23+
class Recurrent(pufferlib.models.LSTMWrapper):
24+
def __init__(self, env, policy, input_size: int = 256, hidden_size: int = 256):
25+
super().__init__(env, policy, input_size=input_size, hidden_size=hidden_size)

pyproject.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,17 @@ metta = [
121121
'metta-mettagrid @ git+https://github.com/metta-ai/metta.git@main#subdirectory=mettagrid',
122122
]
123123

124+
cogames = [
125+
'gym',
126+
'gymnasium',
127+
'omegaconf',
128+
'hydra-core',
129+
'duckdb',
130+
'raylib>=5.5.0',
131+
'mettagrid @ git+https://github.com/metta-ai/metta.git@main#subdirectory=packages/mettagrid',
132+
'cogames @ git+https://github.com/metta-ai/metta.git@main#subdirectory=packages/cogames',
133+
]
134+
124135
microrts = [
125136
'gym==0.23',
126137
'gymnasium==0.29.1',

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ def run(self):
275275
'numpy<2.0',
276276
'shimmy[gym-v21]',
277277
'gym==0.23',
278-
'gymnasium==0.29.1',
279-
'pettingzoo==1.24.1',
278+
'gymnasium>=0.29.1',
279+
'pettingzoo>=1.24.1',
280280
]
281281

282282
if not NO_TRAIN:

0 commit comments

Comments
 (0)