@@ -613,13 +613,16 @@ def swap_partitioned_embeddings(
613613 return io_bytes
614614
615615 if rank == RANK_ZERO :
616- for stats in checkpoint_manager .maybe_read_stats ():
617- yield (
618- stats ["index" ],
619- Stats .from_dict (stats ["eval_stats_before" ]),
620- Stats .from_dict (stats ["stats" ]),
621- Stats .from_dict (stats ["eval_stats_after" ]),
622- )
616+ for stats_dict in checkpoint_manager .maybe_read_stats ():
617+ index : int = stats_dict ["index" ]
618+ stats : Stats = Stats .from_dict (stats_dict ["stats" ])
619+ eval_stats_before : Optional [Stats ] = None
620+ if "eval_stats_before" in stats_dict :
621+ eval_stats_before = Stats .from_dict (stats_dict ["eval_stats_before" ])
622+ eval_stats_after : Optional [Stats ] = None
623+ if "eval_stats_after" in stats_dict :
624+ eval_stats_after = Stats .from_dict (stats_dict ["eval_stats_after" ])
625+ yield (index , eval_stats_before , stats , eval_stats_after )
623626
624627 # Start of the main training loop.
625628 for epoch_idx , edge_path_idx , edge_chunk_idx in iteration_manager :
@@ -771,14 +774,15 @@ def swap_partitioned_embeddings(
771774 bucket_logger .info (f"Stats after training: { eval_stats_after } " )
772775
773776 # Add train/eval metrics to queue
774- checkpoint_manager .append_stats (
775- {
776- "index" : current_index ,
777- "eval_stats_before" : eval_stats_before .to_dict (),
778- "stats" : stats .to_dict (),
779- "eval_stats_after" : eval_stats_after .to_dict (),
780- }
781- )
777+ stats_dict = {
778+ "index" : current_index ,
779+ "stats" : stats .to_dict (),
780+ }
781+ if eval_stats_before is not None :
782+ stats_dict ["eval_stats_before" ] = eval_stats_before .to_dict ()
783+ if eval_stats_after is not None :
784+ stats_dict ["eval_stats_after" ] = eval_stats_after .to_dict ()
785+ checkpoint_manager .append_stats (stats_dict )
782786 yield current_index , eval_stats_before , stats , eval_stats_after
783787
784788 swap_partitioned_embeddings (cur_b , None )
0 commit comments