@@ -50,17 +50,38 @@ def _slice_items(items: List[Dict[str, Any]], split_expr: Any) -> List[Dict[str,
5050 s = str (split_expr ).strip ()
5151 if not s :
5252 return items
53- m = re .search (r"\[\s*(?P<start>-?\d *)\s*:\s*(?P<end>-?\d *)\s*\]" , s )
53+ m = re .search (r"\[\s*(?P<start>-?[^:\]] *)\s*:\s*(?P<end>-?[^\]] *)\s*\]" , s )
5454 if not m and ":" in s :
55- m = re .match (r"\s*(?P<start>-?\d *)\s*:\s*(?P<end>-?\d *)\s*$" , s )
55+ m = re .match (r"\s*(?P<start>-?[^:] *)\s*:\s*(?P<end>-?. *)\s*$" , s )
5656 if not m :
5757 return items
58- start_raw = m .group ("start" )
59- end_raw = m .group ("end" )
60- start = int (start_raw ) if start_raw not in (None , "" , "+" ) else None
61- end = int (end_raw ) if end_raw not in (None , "" , "+" ) else None
62- return items [slice (start , end )]
58+ start_raw = (m .group ("start" ) or "" ).strip ()
59+ end_raw = (m .group ("end" ) or "" ).strip ()
60+ total = len (items )
61+
62+ def _parse_index (raw : str ):
63+ if raw in ("" , "+" ):
64+ return None
65+ if raw .endswith ("%" ):
66+ try :
67+ pct = float (raw [:- 1 ].strip ())
68+ except ValueError :
69+ return None
70+ return int (total * pct / 100.0 )
71+ try :
72+ return int (raw )
73+ except ValueError :
74+ try :
75+ frac = float (raw )
76+ except ValueError :
77+ return None
78+ if 0 <= frac <= 1 :
79+ return int (total * frac )
80+ return None
6381
82+ start = _parse_index (start_raw )
83+ end = _parse_index (end_raw )
84+ return items [slice (start , end )]
6485
6586def _map_dtype (dtype_cfg : Any ) -> Any :
6687 if isinstance (dtype_cfg , torch .dtype ):
@@ -333,12 +354,11 @@ def main() -> int:
333354 for item in args .override :
334355 if item is None :
335356 continue
336- for part in str (item ).split ("," ):
337- part = part .strip ()
338- if part :
339- override_items .append (part )
357+ part = str (item ).strip ()
358+ if part :
359+ override_items .append (part )
340360 if override_items :
341- cfg = apply_overrides (cfg , "," . join ( override_items ) )
361+ cfg = apply_overrides (cfg , override_items )
342362 apply_prompt_defaults (cfg )
343363
344364 seed_val = cfg .get ("seed" , None )
@@ -387,24 +407,48 @@ def main() -> int:
387407 train_ds = Dataset .from_list (train_items )
388408 eval_ds = Dataset .from_list (eval_items ) if eval_items else None
389409
390- model_cfg = cfg .get ("model " ) or {}
410+ model_cfg = cfg .get ("agent_model " ) or {}
391411 if not isinstance (model_cfg , dict ):
392412 model_cfg = {}
393- critic_cfg = cfg .get ("critic " ) or {}
394- if not isinstance (critic_cfg , dict ):
395- critic_cfg = {}
413+ critic_model_cfg = cfg .get ("critic_model " ) or {}
414+ if not isinstance (critic_model_cfg , dict ):
415+ critic_model_cfg = {}
396416 model_name = str (model_cfg .get ("name" ) or "" )
397- if not model_name :
398- raise ValueError ("model.name is required" )
417+ agent_names = cfg .get ("agents" )
418+ if not model_name and not agent_names :
419+ raise ValueError ("agent_model.name or agents is required" )
420+ if agent_names is not None :
421+ if not isinstance (agent_names , (list , tuple )) or not all (
422+ isinstance (x , str ) for x in agent_names
423+ ):
424+ raise ValueError ("agents must be a list of model names." )
425+ agent_names = [str (x ) for x in agent_names ]
426+
427+ critic_names = None
428+ critics_field = cfg .get ("critics" )
429+ if critics_field is not None :
430+ if not isinstance (critics_field , (list , tuple )) or not all (
431+ isinstance (x , str ) for x in critics_field
432+ ):
433+ raise ValueError ("critics must be a list of model names." )
434+ critic_names = [str (x ) for x in critics_field ]
399435 model_kwargs : Dict [str , Any ] = {}
400436
401437 dtype = _map_dtype (model_cfg .get ("dtype" ) or model_cfg .get ("torch_dtype" ))
402438 if dtype is not None :
403439 model_kwargs ["torch_dtype" ] = dtype
404440
405- tokenizer = AutoTokenizer .from_pretrained (model_name )
406- if tokenizer .pad_token is None :
407- tokenizer .pad_token = tokenizer .eos_token
441+ tokenizer_source = agent_names [0 ] if agent_names else model_name
442+ if not tokenizer_source :
443+ raise ValueError ("agent_model.name or agents must be provided." )
444+ if agent_names :
445+ tokenizers = [AutoTokenizer .from_pretrained (name ) for name in agent_names ]
446+ else :
447+ tokenizers = [AutoTokenizer .from_pretrained (tokenizer_source )]
448+ for tok in tokenizers :
449+ if tok .pad_token is None :
450+ tok .pad_token = tok .eos_token
451+ tokenizer = tokenizers [0 ]
408452
409453 iac_args = get_iac_args (cfg , model_name = model_name )
410454 formatters = _build_formatters (cfg , num_agents = num_agents , tokenizer = tokenizer )
@@ -594,14 +638,15 @@ def reward_func(
594638 except Exception :
595639 is_multi_turn = False
596640 critic_model_kwargs : Dict [str , Any ] = {}
597- if isinstance (critic_cfg , dict ):
598- critic_dtype = _map_dtype (critic_cfg .get ("dtype" ) or critic_cfg .get ("torch_dtype" ))
641+ if isinstance (critic_model_cfg , dict ):
642+ critic_dtype = _map_dtype (
643+ critic_model_cfg .get ("dtype" ) or critic_model_cfg .get ("torch_dtype" )
644+ )
599645 if critic_dtype is not None :
600646 critic_model_kwargs ["torch_dtype" ] = critic_dtype
601647
602648 trainer_kwargs : Dict [str , Any ] = {
603- "model" : model_name ,
604- "tokenizer" : tokenizer ,
649+ "tokenizer" : tokenizers if agent_names else tokenizer ,
605650 "reward_func" : reward_func ,
606651 "formatters" : formatters ,
607652 "args" : iac_args ,
@@ -615,15 +660,14 @@ def reward_func(
615660 },
616661 "wandb_config" : wandb_config ,
617662 }
618- critics = None
619- if bool (getattr (iac_args , "use_separate_critic" , True )):
620- critic_name = str (critic_cfg .get ("name" ) or "" ).strip ()
621- if not critic_name :
622- raise ValueError ("critic.name must be provided when use_separate_critic is true" )
623- num_agents_val = int (getattr (iac_args , "num_agents" , 1 ))
624- critics = [critic_name ] * num_agents_val
625- if critics is not None :
626- trainer_kwargs ["critics" ] = critics
663+ trainer_kwargs ["agent_model" ] = model_name or None
664+ if agent_names :
665+ trainer_kwargs ["agents" ] = agent_names
666+ critic_name = str (critic_model_cfg .get ("name" ) or "" ).strip () or None
667+ if critic_name :
668+ trainer_kwargs ["critic_model" ] = critic_name
669+ if critic_names :
670+ trainer_kwargs ["critics" ] = critic_names
627671 if reward_processor is not None :
628672 trainer_kwargs ["reward_processor" ] = reward_processor
629673
0 commit comments