Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit 23e3ff2

Browse files
lwfacebook-github-bot
authored andcommitted
Fix stats checkpointing when no eval-during-train
Summary: Bug introduced in D15977931. Fixes #104. Reviewed By: chandlerzuo Differential Revision: D17498076 fbshipit-source-id: f9a8368c8e7217cd8ed1cd6d22f20e568de2d529
1 parent cb7830b commit 23e3ff2

1 file changed

Lines changed: 19 additions & 15 deletions

File tree

torchbiggraph/train.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -613,13 +613,16 @@ def swap_partitioned_embeddings(
613613
return io_bytes
614614

615615
if rank == RANK_ZERO:
616-
for stats in checkpoint_manager.maybe_read_stats():
617-
yield (
618-
stats["index"],
619-
Stats.from_dict(stats["eval_stats_before"]),
620-
Stats.from_dict(stats["stats"]),
621-
Stats.from_dict(stats["eval_stats_after"]),
622-
)
616+
for stats_dict in checkpoint_manager.maybe_read_stats():
617+
index: int = stats_dict["index"]
618+
stats: Stats = Stats.from_dict(stats_dict["stats"])
619+
eval_stats_before: Optional[Stats] = None
620+
if "eval_stats_before" in stats_dict:
621+
eval_stats_before = Stats.from_dict(stats_dict["eval_stats_before"])
622+
eval_stats_after: Optional[Stats] = None
623+
if "eval_stats_after" in stats_dict:
624+
eval_stats_after = Stats.from_dict(stats_dict["eval_stats_after"])
625+
yield (index, eval_stats_before, stats, eval_stats_after)
623626

624627
# Start of the main training loop.
625628
for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
@@ -771,14 +774,15 @@ def swap_partitioned_embeddings(
771774
bucket_logger.info(f"Stats after training: {eval_stats_after}")
772775

773776
# Add train/eval metrics to queue
774-
checkpoint_manager.append_stats(
775-
{
776-
"index": current_index,
777-
"eval_stats_before": eval_stats_before.to_dict(),
778-
"stats": stats.to_dict(),
779-
"eval_stats_after": eval_stats_after.to_dict(),
780-
}
781-
)
777+
stats_dict = {
778+
"index": current_index,
779+
"stats": stats.to_dict(),
780+
}
781+
if eval_stats_before is not None:
782+
stats_dict["eval_stats_before"] = eval_stats_before.to_dict()
783+
if eval_stats_after is not None:
784+
stats_dict["eval_stats_after"] = eval_stats_after.to_dict()
785+
checkpoint_manager.append_stats(stats_dict)
782786
yield current_index, eval_stats_before, stats, eval_stats_after
783787

784788
swap_partitioned_embeddings(cur_b, None)

0 commit comments

Comments
 (0)