Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit f0dd9a4

Browse files
chandlerzuo authored and facebook-github-bot committed
Checkpoint Learning Stats (#78)
Summary: Pull Request resolved: #78 Currently, when resuming from a failed training, learning curve stats history is lost. This diff adds the learning curve stats in the checkpoint file. Reviewed By: lerks Differential Revision: D15977931 fbshipit-source-id: 7031e61f28fa9dc11f9424a67e1447ec64899f01
1 parent 75885ed commit f0dd9a4

5 files changed

Lines changed: 107 additions & 4 deletions

File tree

test/test_functional.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import time
1515
from functools import partial
1616
from tempfile import TemporaryDirectory
17-
from typing import Dict, Iterable, List, NamedTuple, Tuple
17+
from typing import Dict, Iterable, List, Mapping, NamedTuple, Tuple, Union
1818
from unittest import TestCase, main
1919

2020
import attr
@@ -28,6 +28,7 @@
2828
)
2929
from torchbiggraph.eval import do_eval
3030
from torchbiggraph.partitionserver import run_partition_server
31+
from torchbiggraph.stats import SerializedStats
3132
from torchbiggraph.train import train
3233
from torchbiggraph.util import (
3334
call_one_after_the_other,
@@ -225,6 +226,20 @@ def assertIsEmbeddings(
225226
self.assertTrue(np.all(np.isfinite(dataset[...])))
226227
self.assertTrue(np.all(np.linalg.norm(dataset[...], axis=-1) != 0))
227228

229+
def assertIsStatsDict(self, stats: Mapping[str, Union[int, SerializedStats]]) -> None:
    """Check that *stats* is a dict with an int under "index" and, under every
    other key, a serialized stats dict of the form
    {"count": int, "metrics": {metric_name: float}}.
    """
    self.assertIsInstance(stats, dict)
    self.assertIn("index", stats)
    for key, value in stats.items():
        if key == "index":
            self.assertIsInstance(value, int)
            continue
        # Any non-"index" entry must be a serialized Stats payload.
        self.assertIsInstance(value, dict)
        self.assertCountEqual(value.keys(), ["count", "metrics"])
        self.assertIsInstance(value["count"], int)
        self.assertIsInstance(value["metrics"], dict)
        for metric_value in value["metrics"].values():
            self.assertIsInstance(metric_value, float)
242+
228243
def assertCheckpointWritten(self, config: ConfigSchema, *, version: int) -> None:
229244
with open(os.path.join(config.checkpoint_path, "checkpoint_version.txt"), "rt") as tf:
230245
self.assertEqual(version, int(tf.read().strip()))
@@ -239,6 +254,10 @@ def assertCheckpointWritten(self, config: ConfigSchema, *, version: int) -> None
239254
self.assertIsModelParameters(hf["model"])
240255
self.assertIsOptimStateDict(hf["optimizer/state_dict"])
241256

257+
with open(os.path.join(config.checkpoint_path, "training_stats.json"), "rt") as tf:
258+
for line in tf:
259+
self.assertIsStatsDict(json.loads(line))
260+
242261
for entity_name, entity in config.entities.items():
243262
for partition in range(entity.num_partitions):
244263
with open(os.path.join(

torchbiggraph/checkpoint_manager.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,18 @@
1414
import re
1515
from abc import ABC, abstractmethod
1616
from collections import OrderedDict
17-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
17+
from typing import (
18+
Any,
19+
Callable,
20+
Dict,
21+
Generator,
22+
List,
23+
Mapping,
24+
Optional,
25+
Set,
26+
Tuple,
27+
Union,
28+
)
1829

1930
import numpy as np
2031
import torch
@@ -27,6 +38,7 @@
2738
)
2839
from torchbiggraph.config import ConfigSchema
2940
from torchbiggraph.parameter_sharing import ParameterClient
41+
from torchbiggraph.stats import SerializedStats
3042
from torchbiggraph.types import (
3143
ByteTensorType,
3244
EntityName,
@@ -418,6 +430,22 @@ def read_config(self) -> ConfigSchema:
418430
config_json = self.storage.load_config()
419431
return ConfigSchema.from_dict(json.loads(config_json))
420432

433+
def append_stats(
    self,
    stats: Mapping[str, Union[int, SerializedStats]],
) -> None:
    """Serialize *stats* to JSON and append it to the backing stats log."""
    serialized = json.dumps(stats)
    self.storage.append_stats(serialized)
438+
439+
def read_stats(self) -> Generator[Dict[str, Union[int, SerializedStats]], None, None]:
    """Yield each entry of the backing stats log, decoded from JSON."""
    yield from map(json.loads, self.storage.load_stats())
442+
443+
def maybe_read_stats(self) -> Generator[Dict[str, Union[int, SerializedStats]], None, None]:
    """Like read_stats, but yields nothing when the stats log cannot be loaded
    (e.g. on the first run, before any stats have been written)."""
    try:
        for entry in self.read_stats():
            yield entry
    except CouldNotLoadData:
        # Best-effort: a missing log is not an error, just an empty history.
        pass
448+
421449
def write_new_version(self, config: ConfigSchema) -> None:
422450
if self.background:
423451
self._sync()

torchbiggraph/checkpoint_storage.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import os
1212
from abc import ABC, abstractmethod
1313
from pathlib import Path
14-
from typing import Any, Dict, NamedTuple, Optional, Tuple
14+
from typing import Any, Dict, Generator, NamedTuple, Optional, Tuple
1515

1616
import h5py
1717
import numpy as np
@@ -116,6 +116,14 @@ def save_config(self, config_json: str) -> None:
116116
def load_config(self) -> str:
117117
pass
118118

119+
@abstractmethod
def append_stats(self, stats_json: str) -> None:
    """Persist one JSON-encoded stats entry by appending it to the stats log."""
122+
123+
@abstractmethod
def load_stats(self) -> Generator[str, None, None]:
    """Yield the JSON-encoded stats entries previously appended to the log."""
126+
119127
@abstractmethod
120128
def prepare_snapshot(self, version: int, epoch_idx: int) -> None:
121129
pass
@@ -278,6 +286,11 @@ def get_model_file(self, version: int, *, path: Optional[Path] = None) -> Path:
278286
path = self.path
279287
return path / f"model.v{version}.h5"
280288

289+
def get_stats_file(self, *, path: Optional[Path] = None) -> Path:
    """Return the location of the training-stats log file, under *path* if
    given, otherwise under this storage's own directory."""
    base = self.path if path is None else path
    return base / "training_stats.json"
293+
281294
def get_snapshot_path(self, epoch_idx: int) -> Path:
282295
return self.path / f"epoch_{epoch_idx}"
283296

@@ -417,6 +430,18 @@ def load_config(self) -> str:
417430
except FileNotFoundError as err:
418431
raise CouldNotLoadData() from err
419432

433+
def append_stats(self, stats_json: str) -> None:
    """Append one JSON-encoded stats entry, as its own line, to the stats file."""
    stats_file = self.get_stats_file()
    # Append mode, so earlier entries (and earlier runs) are preserved.
    with stats_file.open("at") as out:
        out.write(stats_json + "\n")
436+
437+
def load_stats(self) -> Generator[str, None, None]:
    """Yield the raw lines (trailing newlines included) of the stats file.

    Raises CouldNotLoadData when the file does not exist yet.
    """
    try:
        with self.get_stats_file().open("rt") as stats_file:
            yield from stats_file
    except FileNotFoundError as err:
        raise CouldNotLoadData() from err
444+
420445
def prepare_snapshot(self, version: int, epoch_idx: int) -> None:
421446
self.get_snapshot_path(epoch_idx).mkdir(parents=True, exist_ok=True)
422447

torchbiggraph/stats.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from collections import defaultdict
1010
from statistics import mean
11-
from typing import Iterable, Type
11+
from typing import Dict, Iterable, Type, Union
1212

1313
from torchbiggraph.types import FloatTensorType
1414

@@ -17,6 +17,9 @@ def average_of_sums(*tensors: FloatTensorType) -> float:
1717
return mean(t.sum().item() for t in tensors)
1818

1919

20+
# JSON-friendly form of a Stats object: {"count": int, "metrics": {name: value}}.
SerializedStats = Dict[str, Union[int, Dict[str, float]]]
21+
22+
2023
class Stats:
2124
"""A class collecting a set of metrics.
2225
@@ -66,3 +69,14 @@ def __eq__(self, other: "Stats") -> bool:
6669
return (isinstance(other, Stats)
6770
and self.count == other.count
6871
and self.metrics == other.metrics)
72+
73+
def to_dict(self) -> SerializedStats:
    """Return a JSON-serializable dict holding this object's count and metrics."""
    serialized: SerializedStats = {
        "count": self.count,
        "metrics": self.metrics,
    }
    return serialized
75+
76+
@classmethod
77+
def from_dict(cls, d: SerializedStats) -> "Stats":
78+
if set(d.keys()) != {"count", "metrics"}:
79+
raise ValueError(
80+
f"Expect keys ['count', 'metrics'] from input but get {list(d.keys())}."
81+
)
82+
return Stats(count=d["count"], **d["metrics"])

torchbiggraph/train.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,15 @@ def swap_partitioned_embeddings(
612612

613613
return io_bytes
614614

615+
if rank == RANK_ZERO:
616+
for stats in checkpoint_manager.maybe_read_stats():
617+
yield (
618+
stats["index"],
619+
Stats.from_dict(stats["eval_stats_before"]),
620+
Stats.from_dict(stats["stats"]),
621+
Stats.from_dict(stats["eval_stats_after"]),
622+
)
623+
615624
# Start of the main training loop.
616625
for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
617626
logger.info(
@@ -762,6 +771,14 @@ def swap_partitioned_embeddings(
762771
bucket_logger.info(f"Stats after training: {eval_stats_after}")
763772

764773
# Add train/eval metrics to queue
774+
checkpoint_manager.append_stats(
775+
{
776+
"index": current_index,
777+
"eval_stats_before": eval_stats_before.to_dict(),
778+
"stats": stats.to_dict(),
779+
"eval_stats_after": eval_stats_after.to_dict(),
780+
}
781+
)
765782
yield current_index, eval_stats_before, stats, eval_stats_after
766783

767784
swap_partitioned_embeddings(cur_b, None)

0 commit comments

Comments (0)