Skip to content

Commit 9c2e006

Browse files
fix: heterogenous loading and critic conflict handling (#3)
* hetero * ud * ud param * ud * ud * override * slice * ud * ud * ud
1 parent a27947f commit 9c2e006

17 files changed

Lines changed: 575 additions & 266 deletions

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Override any configuration value inline with `--override`:
5555
```bash
5656
python3 str_build/train/train_magrpo.py \
5757
--config str_build/configs/str_build_magrpo_config.yaml \
58-
--override model.name='Qwen/Qwen2.5-1.5B-Instruct' magrpo.num_turns=1
58+
--override agent_model.name='Qwen/Qwen2.5-1.5B-Instruct' magrpo.num_turns=1
5959
```
6060

6161
## Multi-Turn External Feedback
@@ -76,4 +76,3 @@ HouseBuild modes:
7676
- `rect_modification`
7777
- `resource_schedule`
7878
- `score_feedback`
79-

house_build/configs/house_build_iac_config.yaml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
1-
model:
1+
agent_model:
22
name: Qwen/Qwen3-4B-Instruct-2507
33
type: qwen
44
temperature: 0.6
55
top_p: 0.6
66
max_length: 2048
77
dtype: bf16
88

9-
critic:
9+
agents: null
10+
11+
critic_model:
1012
name: "Qwen/Qwen3-4B-Instruct-2507"
1113
type: qwen
1214
temperature: 0.6
1315
top_p: 0.6
1416
max_length: 2048
1517
dtype: bf16
1618

19+
critics: null
1720

1821
dataset:
1922
name: house_build
2023
type: house_build
2124
json_path: ../dataset/data.json
22-
train_split: "[:]"
25+
train_split: "[:8]"
26+
eval_split: "[8:]"
2327

2428
output:
2529
base_dir: output
@@ -50,8 +54,8 @@ iac:
5054
use_separate_critic: true
5155
discount: 0.9
5256
early_termination_threshold: 0.0
53-
eval_interval: 0
54-
eval_num_samples: 1
57+
eval_interval: 10
58+
eval_num_samples: 2
5559
eval_batch_size: 1
5660
logging_steps: 40
5761

house_build/configs/house_build_maac_config.yaml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
1-
model:
1+
agent_model:
22
name: Qwen/Qwen3-4B-Instruct-2507
33
type: qwen
44
temperature: 0.6
55
top_p: 0.6
66
max_length: 2048
77
dtype: bf16
88

9-
critic:
9+
agents: null
10+
11+
critic_model:
1012
name: "Qwen/Qwen3-4B-Instruct-2507"
1113
type: qwen
1214
temperature: 0.6
1315
top_p: 0.6
1416
max_length: 2048
1517
dtype: bf16
1618

19+
critics: null
1720

1821
dataset:
1922
name: house_build
2023
type: house_build
2124
json_path: ../dataset/data.json
22-
train_split: "[:]"
25+
train_split: "[:8]"
26+
eval_split: "[8:]"
2327

2428
output:
2529
base_dir: output
@@ -49,8 +53,8 @@ maac:
4953
top_k: null
5054
discount: 0.9
5155
early_termination_threshold: 0.0
52-
eval_interval: 0
53-
eval_num_samples: 1
56+
eval_interval: 10
57+
eval_num_samples: 2
5458
eval_batch_size: 1
5559
logging_steps: 40
5660

house_build/configs/house_build_magrpo_config.yaml

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1-
model:
1+
agent_model:
22
name: Qwen/Qwen3-4B-Instruct-2507
33
type: qwen
44
temperature: 0.6
55
top_p: 0.6
66
max_length: 2048
77
dtype: bf16
88

9+
agents: null
10+
11+
critic_model: null
12+
13+
critics: null
14+
915
dataset:
1016
name: house_build
1117
type: house_build
1218
json_path: ../dataset/data.json
13-
train_split: "[:]"
19+
train_split: "[:8]"
20+
eval_split: "[8:]"
1421

1522
output:
1623
base_dir: output
@@ -41,8 +48,8 @@ magrpo:
4148
rollout_buffer_size: 1
4249
train_batch_size: 1
4350
advantage_normalization: true
44-
eval_interval: 0
45-
eval_num_samples: 0
51+
eval_interval: 2
52+
eval_num_samples: 2
4653
eval_batch_size: 1
4754

4855
reward_processor:

house_build/external/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
from . import score_feedback
2020

2121

22-
# Verbose toggle for external previews
2322
VERBOSE = False
2423

2524

26-
# Context resolver API
2725
_context_resolver: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None
2826

2927

house_build/train/train_iac.py

Lines changed: 78 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,38 @@ def _slice_items(items: List[Dict[str, Any]], split_expr: Any) -> List[Dict[str,
5050
s = str(split_expr).strip()
5151
if not s:
5252
return items
53-
m = re.search(r"\[\s*(?P<start>-?\d*)\s*:\s*(?P<end>-?\d*)\s*\]", s)
53+
m = re.search(r"\[\s*(?P<start>-?[^:\]]*)\s*:\s*(?P<end>-?[^\]]*)\s*\]", s)
5454
if not m and ":" in s:
55-
m = re.match(r"\s*(?P<start>-?\d*)\s*:\s*(?P<end>-?\d*)\s*$", s)
55+
m = re.match(r"\s*(?P<start>-?[^:]*)\s*:\s*(?P<end>-?.*)\s*$", s)
5656
if not m:
5757
return items
58-
start_raw = m.group("start")
59-
end_raw = m.group("end")
60-
start = int(start_raw) if start_raw not in (None, "", "+") else None
61-
end = int(end_raw) if end_raw not in (None, "", "+") else None
62-
return items[slice(start, end)]
58+
start_raw = (m.group("start") or "").strip()
59+
end_raw = (m.group("end") or "").strip()
60+
total = len(items)
61+
62+
def _parse_index(raw: str):
63+
if raw in ("", "+"):
64+
return None
65+
if raw.endswith("%"):
66+
try:
67+
pct = float(raw[:-1].strip())
68+
except ValueError:
69+
return None
70+
return int(total * pct / 100.0)
71+
try:
72+
return int(raw)
73+
except ValueError:
74+
try:
75+
frac = float(raw)
76+
except ValueError:
77+
return None
78+
if 0 <= frac <= 1:
79+
return int(total * frac)
80+
return None
6381

82+
start = _parse_index(start_raw)
83+
end = _parse_index(end_raw)
84+
return items[slice(start, end)]
6485

6586
def _map_dtype(dtype_cfg: Any) -> Any:
6687
if isinstance(dtype_cfg, torch.dtype):
@@ -333,12 +354,11 @@ def main() -> int:
333354
for item in args.override:
334355
if item is None:
335356
continue
336-
for part in str(item).split(","):
337-
part = part.strip()
338-
if part:
339-
override_items.append(part)
357+
part = str(item).strip()
358+
if part:
359+
override_items.append(part)
340360
if override_items:
341-
cfg = apply_overrides(cfg, ",".join(override_items))
361+
cfg = apply_overrides(cfg, override_items)
342362
apply_prompt_defaults(cfg)
343363

344364
seed_val = cfg.get("seed", None)
@@ -387,24 +407,48 @@ def main() -> int:
387407
train_ds = Dataset.from_list(train_items)
388408
eval_ds = Dataset.from_list(eval_items) if eval_items else None
389409

390-
model_cfg = cfg.get("model") or {}
410+
model_cfg = cfg.get("agent_model") or {}
391411
if not isinstance(model_cfg, dict):
392412
model_cfg = {}
393-
critic_cfg = cfg.get("critic") or {}
394-
if not isinstance(critic_cfg, dict):
395-
critic_cfg = {}
413+
critic_model_cfg = cfg.get("critic_model") or {}
414+
if not isinstance(critic_model_cfg, dict):
415+
critic_model_cfg = {}
396416
model_name = str(model_cfg.get("name") or "")
397-
if not model_name:
398-
raise ValueError("model.name is required")
417+
agent_names = cfg.get("agents")
418+
if not model_name and not agent_names:
419+
raise ValueError("agent_model.name or agents is required")
420+
if agent_names is not None:
421+
if not isinstance(agent_names, (list, tuple)) or not all(
422+
isinstance(x, str) for x in agent_names
423+
):
424+
raise ValueError("agents must be a list of model names.")
425+
agent_names = [str(x) for x in agent_names]
426+
427+
critic_names = None
428+
critics_field = cfg.get("critics")
429+
if critics_field is not None:
430+
if not isinstance(critics_field, (list, tuple)) or not all(
431+
isinstance(x, str) for x in critics_field
432+
):
433+
raise ValueError("critics must be a list of model names.")
434+
critic_names = [str(x) for x in critics_field]
399435
model_kwargs: Dict[str, Any] = {}
400436

401437
dtype = _map_dtype(model_cfg.get("dtype") or model_cfg.get("torch_dtype"))
402438
if dtype is not None:
403439
model_kwargs["torch_dtype"] = dtype
404440

405-
tokenizer = AutoTokenizer.from_pretrained(model_name)
406-
if tokenizer.pad_token is None:
407-
tokenizer.pad_token = tokenizer.eos_token
441+
tokenizer_source = agent_names[0] if agent_names else model_name
442+
if not tokenizer_source:
443+
raise ValueError("agent_model.name or agents must be provided.")
444+
if agent_names:
445+
tokenizers = [AutoTokenizer.from_pretrained(name) for name in agent_names]
446+
else:
447+
tokenizers = [AutoTokenizer.from_pretrained(tokenizer_source)]
448+
for tok in tokenizers:
449+
if tok.pad_token is None:
450+
tok.pad_token = tok.eos_token
451+
tokenizer = tokenizers[0]
408452

409453
iac_args = get_iac_args(cfg, model_name=model_name)
410454
formatters = _build_formatters(cfg, num_agents=num_agents, tokenizer=tokenizer)
@@ -594,14 +638,15 @@ def reward_func(
594638
except Exception:
595639
is_multi_turn = False
596640
critic_model_kwargs: Dict[str, Any] = {}
597-
if isinstance(critic_cfg, dict):
598-
critic_dtype = _map_dtype(critic_cfg.get("dtype") or critic_cfg.get("torch_dtype"))
641+
if isinstance(critic_model_cfg, dict):
642+
critic_dtype = _map_dtype(
643+
critic_model_cfg.get("dtype") or critic_model_cfg.get("torch_dtype")
644+
)
599645
if critic_dtype is not None:
600646
critic_model_kwargs["torch_dtype"] = critic_dtype
601647

602648
trainer_kwargs: Dict[str, Any] = {
603-
"model": model_name,
604-
"tokenizer": tokenizer,
649+
"tokenizer": tokenizers if agent_names else tokenizer,
605650
"reward_func": reward_func,
606651
"formatters": formatters,
607652
"args": iac_args,
@@ -615,15 +660,14 @@ def reward_func(
615660
},
616661
"wandb_config": wandb_config,
617662
}
618-
critics = None
619-
if bool(getattr(iac_args, "use_separate_critic", True)):
620-
critic_name = str(critic_cfg.get("name") or "").strip()
621-
if not critic_name:
622-
raise ValueError("critic.name must be provided when use_separate_critic is true")
623-
num_agents_val = int(getattr(iac_args, "num_agents", 1))
624-
critics = [critic_name] * num_agents_val
625-
if critics is not None:
626-
trainer_kwargs["critics"] = critics
663+
trainer_kwargs["agent_model"] = model_name or None
664+
if agent_names:
665+
trainer_kwargs["agents"] = agent_names
666+
critic_name = str(critic_model_cfg.get("name") or "").strip() or None
667+
if critic_name:
668+
trainer_kwargs["critic_model"] = critic_name
669+
if critic_names:
670+
trainer_kwargs["critics"] = critic_names
627671
if reward_processor is not None:
628672
trainer_kwargs["reward_processor"] = reward_processor
629673

0 commit comments

Comments
 (0)