This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Commit 5d43050

lw authored and facebook-github-bot committed
Fix race condition when checkpointing stats
Summary:
I recently discovered that the checkpointing of training stats (introduced in D15977931) had a few issues:
- It could run into race conditions (different trainers appending to the file at the same time), which put the file in a corrupted state, crashing all subsequent runs that tried to resume from the checkpoint (and even if we avoided the crash, we couldn't trust the data in there).
- It could be "ahead" of the rest of the checkpoint, meaning that it could contain stats for training steps whose results weren't checkpointed yet (we checkpoint only at the end of a pass, but we log stats for each bucket within a pass). Thus, after resuming, we would have duplicate stats in the checkpoint, as some were recomputed without the old ones being purged.

There were several possible solutions:
- Locking the file (using low-level fcntl calls) before appending to it. Although this is typically supported by filesystems, even distributed ones, it's a bit of an advanced feature, so I was wary of using it. Also, the rest of our filesystem code avoids race conditions through higher-level logic (assigning writes to different files to different trainers), so it would be dissonant to solve these race conditions at such a low level. Additionally, it would mean that each new storage backend needs to reinvent a different low-level way to lock files.
- Having each trainer write stats to a different file. This would mean that the checkpoint format depends on the number of trainers, which currently isn't the case. Philosophically speaking, the number of trainers is a detail of the execution, not intrinsic to the data, so it shouldn't show up in the checkpoint. Practically, it means that if we started a run with 10 trainers and then resumed it with only 9, we would not know that there is one extra stats checkpoint (the last one), and thus we wouldn't load it.
- Finding a serialization format that produces blobs of the same size for all stats, and writing them to the file at an offset determined by the stats index. Each stats entry would then get its own region of the file, disjoint from all others, determined "intrinsically", so even if two writes occurred at the same time they would not touch the same region of the file. This again seems like an ad-hoc solution for this particular storage, and moreover it makes the format rather hard to extend (if we want to add more metrics).
- Finally, the solution I went with here: collect the stats on a single trainer and have it checkpoint them, ideally at the end of the training pass (at the same time as all the rest of the data). The natural place to do this is the lock server, as stats can be reported by trainers when they release the bucket they are working on. The lock server then also knows which bucket these stats are for (information that currently gets lost). This may in the future allow fancier representations of these stats (imagine, for each pass, a 2D matrix of the loss for each bucket, to see whether one row/column is learning more poorly than the others).

Reviewed By: adamlerer, chandlerzuo

Differential Revision: D17605390

fbshipit-source-id: 9720ddbd0a30624fe101a7146acfacf868a10429
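Of the rejected options, the fcntl-based one is the easiest to make concrete. A minimal sketch of what such a locked appender could have looked like (the function `append_stats_locked` is hypothetical, not part of the codebase), assuming a local POSIX filesystem:

```python
import fcntl
import json


def append_stats_locked(path: str, stats_dict: dict) -> None:
    # Hypothetical sketch of the rejected option: take an exclusive advisory
    # lock before appending, so concurrent trainers cannot interleave partial
    # writes and corrupt the JSON-lines stats file.
    with open(path, "a") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # blocks until no other process holds it
        try:
            f.write(json.dumps(stats_dict) + "\n")
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)
```

Advisory locks only coordinate processes that all take the lock, and their semantics on networked filesystems vary, which is part of why this option was discarded in favor of higher-level coordination.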
1 parent 53c9ce2 commit 5d43050

5 files changed

Lines changed: 94 additions & 35 deletions

test/test_functional.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -230,15 +230,18 @@ def assertIsStatsDict(self, stats: Mapping[str, Union[int, SerializedStats]]) ->
         self.assertIsInstance(stats, dict)
         self.assertIn("index", stats)
         for k, v in stats.items():
-            if k == "index":
+            if k in ("epoch_idx", "edge_path_idx", "edge_chunk_idx",
+                     "lhs_partition", "rhs_partition", "index"):
                 self.assertIsInstance(v, int)
-            else:
+            elif k in ("stats", "eval_stats_before", "eval_stats_after"):
                 self.assertIsInstance(v, dict)
                 self.assertCountEqual(v.keys(), ["count", "metrics"])
                 self.assertIsInstance(v["count"], int)
                 self.assertIsInstance(v["metrics"], dict)
                 for m in v["metrics"].values():
                     self.assertIsInstance(m, float)
+            else:
+                self.fail(f"Unknown stats key: {k}")
 
     def assertCheckpointWritten(self, config: ConfigSchema, *, version: int) -> None:
         with open(os.path.join(config.checkpoint_path, "checkpoint_version.txt"), "rt") as tf:
```

torchbiggraph/bucket_scheduling.py

Lines changed: 42 additions & 10 deletions
```diff
@@ -9,12 +9,13 @@
 import logging
 import random
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, NamedTuple, Optional, Set, Tuple
 
 from torch_extensions.rpc.rpc import Client, Server
 
 from torchbiggraph.config import BucketOrder
 from torchbiggraph.distributed import Startable
+from torchbiggraph.stats import Stats
 from torchbiggraph.types import Bucket, EntityName, Partition, Rank, Side
 
 
@@ -25,6 +26,16 @@
 ### Bucket scheduling interface.
 ###
 
+class BucketStats(NamedTuple):
+    lhs_partition: int
+    rhs_partition: int
+    # A global sequence number, tracking the order in which buckets are trained.
+    index: int
+    train: Stats
+    eval_before: Optional[Stats] = None
+    eval_after: Optional[Stats] = None
+
+
 class AbstractBucketScheduler(ABC):
 
     @abstractmethod
@@ -36,7 +47,7 @@ def acquire_bucket(self) -> Tuple[Optional[Bucket], int]:
         pass
 
     @abstractmethod
-    def release_bucket(self, bucket: Bucket) -> None:
+    def release_bucket(self, bucket: Bucket, stats: BucketStats) -> None:
         pass
 
     @abstractmethod
@@ -47,6 +58,10 @@ def check_and_set_dirty(self, entity: EntityName, part: Partition) -> bool:
     def peek(self) -> Optional[Bucket]:
         pass
 
+    @abstractmethod
+    def get_stats_for_pass(self) -> List[BucketStats]:
+        pass
+
 
 ###
 ### Implementation for single-machine mode.
@@ -259,6 +274,7 @@ def __init__(self, nparts_lhs: int, nparts_rhs: int, order: BucketOrder) -> None
         self.order = order
 
         self.buckets: List[Bucket] = []
+        self.stats: List[BucketStats] = []
 
     def new_pass(self, is_first: bool) -> None:
         self.buckets = create_ordered_buckets(
@@ -267,6 +283,7 @@ def new_pass(self, is_first: bool) -> None:
             order=self.order,
             generator=random.Random(),
         )
+        self.stats = []
 
         # Print buckets
         logger.debug("Partition pairs:")
@@ -282,8 +299,10 @@ def acquire_bucket(self) -> Tuple[Optional[Bucket], int]:
         remaining = len(self.buckets)
         return bucket, remaining
 
-    def release_bucket(self, bucket: Bucket) -> None:
-        pass
+    def release_bucket(self, bucket: Bucket, stats: BucketStats) -> None:
+        if stats.lhs_partition != bucket.lhs or stats.rhs_partition != bucket.rhs:
+            raise ValueError(f"Bucket and stats don't match: {bucket}, {stats}")
+        self.stats.append(stats)
 
     def check_and_set_dirty(self, entity: EntityName, part: Partition) -> bool:
         return False
@@ -294,6 +313,9 @@ def peek(self) -> Optional[Bucket]:
         except IndexError:
             return None
 
+    def get_stats_for_pass(self) -> List[BucketStats]:
+        return self.stats.copy()
+
 
 ###
 ### Implementation for distributed training mode.
@@ -325,13 +347,15 @@ def __init__(
         self.active: Dict[Bucket, Rank] = {}
         self.done: Set[Bucket] = set()
         self.dirty: Set[Tuple[EntityName, Partition]] = set()
+        self.stats: List[BucketStats] = []
         self.initialized_partitions: Optional[Set[Partition]] = None
 
     def new_pass(self, is_first: bool = False) -> None:
         """Start a new epoch of training."""
         self.active = {}
         self.done = set()
         self.dirty = set()
+        self.stats = []
         if self.init_tree and is_first:
             self.initialized_partitions = {Partition(0)}
         else:
@@ -404,13 +428,15 @@ def acquire_bucket(
 
         return None, remaining
 
-    def release_bucket(self, bucket: Bucket) -> None:
+    def release_bucket(self, bucket: Bucket, stats: BucketStats) -> None:
         """
         Releases the lock on lhs and rhs, and marks this pair as done.
        """
-        if bucket.lhs is not None:
-            self.active.pop(bucket)
-            logger.info(f"Bucket {bucket} released: active= {self.active}")
+        if stats.lhs_partition != bucket.lhs or stats.rhs_partition != bucket.rhs:
+            raise ValueError(f"Bucket and stats don't match: {bucket}, {stats}")
+        self.active.pop(bucket)
+        self.stats.append(stats)
+        logger.info(f"Bucket {bucket} released: active= {self.active}")
 
     def check_and_set_dirty(self, entity: EntityName, part: Partition) -> bool:
         """
@@ -424,6 +450,9 @@ def check_and_set_dirty(self, entity: EntityName, part: Partition) -> bool:
             self.dirty.add(key)
         return res
 
+    def get_stats_for_pass(self) -> List[BucketStats]:
+        return sorted(self.stats, key=lambda s: s.index)
+
 
 class LockClient(Client):
 
@@ -450,11 +479,14 @@ def acquire_bucket(self) -> Tuple[Optional[Bucket], int]:
         self.old_b = bucket
         return bucket, remaining
 
-    def release_bucket(self, bucket: Bucket) -> None:
-        self.client.release_bucket(bucket)
+    def release_bucket(self, bucket: Bucket, stats: BucketStats) -> None:
+        self.client.release_bucket(bucket, stats)
 
     def check_and_set_dirty(self, entity: EntityName, part: Partition) -> bool:
         return self.client.check_and_set_dirty(entity, part)
 
     def peek(self) -> Optional[Bucket]:
         return None
+
+    def get_stats_for_pass(self) -> List[BucketStats]:
+        return self.client.get_stats_for_pass()
```
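The record introduced by this diff can be exercised standalone. Below is a sketch that mirrors the `BucketStats` NamedTuple and the sorting done by `get_stats_for_pass`, under the simplifying assumption that the `Stats` fields are replaced by a bare float loss for illustration:

```python
from typing import List, NamedTuple, Optional


class BucketStats(NamedTuple):
    # Same shape as the NamedTuple added in the diff, except that the
    # train/eval fields hold a plain float loss instead of a Stats object.
    lhs_partition: int
    rhs_partition: int
    index: int  # global sequence number of the bucket within the pass
    train: float
    eval_before: Optional[float] = None
    eval_after: Optional[float] = None


def get_stats_for_pass(collected: List[BucketStats]) -> List[BucketStats]:
    # Trainers may release buckets out of order, so the lock server re-sorts
    # by the global index before the stats are checkpointed.
    return sorted(collected, key=lambda s: s.index)


collected = [
    BucketStats(lhs_partition=1, rhs_partition=0, index=2, train=0.31),
    BucketStats(lhs_partition=0, rhs_partition=0, index=0, train=0.52),
    BucketStats(lhs_partition=0, rhs_partition=1, index=1, train=0.44),
]
ordered = get_stats_for_pass(collected)
print([s.index for s in ordered])  # [0, 1, 2]
```

Defaulting `eval_before`/`eval_after` to `None` lets buckets that skipped evaluation omit those fields, matching the optional keys in the serialized stats dicts.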

torchbiggraph/checkpoint_manager.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -20,7 +20,6 @@
     Dict,
     Generator,
     List,
-    Mapping,
     Optional,
     Set,
     Tuple,
@@ -432,9 +431,9 @@ def read_config(self) -> ConfigSchema:
 
     def append_stats(
         self,
-        stats: Mapping[str, Union[int, SerializedStats]],
+        stats: List[Dict[str, Union[int, SerializedStats]]],
     ) -> None:
-        self.storage.append_stats(json.dumps(stats))
+        self.storage.append_stats([json.dumps(s) for s in stats])
 
     def read_stats(self) -> Generator[Dict[str, Union[int, SerializedStats]], None, None]:
         for line in self.storage.load_stats():
```
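The stats file remains a JSON-lines log: `append_stats` serializes each dict to its own line, and `read_stats` decodes them line by line. A sketch of that round trip (the helper names `dump_stats` and `parse_stats` are illustrative, not the real API):

```python
import json
from typing import Dict, Iterable, Iterator, List


def dump_stats(stats: List[dict]) -> List[str]:
    # Serialize each stats dict to its own JSON line, as append_stats does.
    return [json.dumps(s) for s in stats]


def parse_stats(lines: Iterable[str]) -> Iterator[dict]:
    # Mirror of read_stats: each line decodes independently, so a truncated
    # tail only loses the last entries instead of the whole file.
    for line in lines:
        yield json.loads(line)


serialized = dump_stats(
    [{"index": 0, "stats": {"count": 10, "metrics": {"loss": 0.5}}}]
)
restored = list(parse_stats(serialized))
print(restored[0]["stats"]["metrics"]["loss"])  # 0.5
```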

torchbiggraph/checkpoint_storage.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -11,7 +11,7 @@
 import os
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, Generator, NamedTuple, Optional, Tuple
+from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple
 
 import h5py
 import numpy as np
@@ -117,7 +117,7 @@ def load_config(self) -> str:
         pass
 
     @abstractmethod
-    def append_stats(self, stats_json: str) -> None:
+    def append_stats(self, stats_json: List[str]) -> None:
         pass
 
     @abstractmethod
@@ -430,9 +430,9 @@ def load_config(self) -> str:
         except FileNotFoundError as err:
             raise CouldNotLoadData() from err
 
-    def append_stats(self, stats_json: str) -> None:
+    def append_stats(self, stats_json: List[str]) -> None:
         with self.get_stats_file().open("at") as tf:
-            tf.write(f"{stats_json}\n")
+            tf.write("".join(f"{s}\n" for s in stats_json))
 
     def load_stats(self) -> Generator[str, None, None]:
         try:
```
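The updated storage method can be tried in isolation: all of a pass's serialized stats are joined and handed to a single write call, instead of one append per bucket. A sketch assuming a plain local file (the path is made up):

```python
import json
from pathlib import Path
from typing import List


def append_stats(stats_file: Path, stats_json: List[str]) -> None:
    # Join the whole pass's stats lines and append them with one write call,
    # mirroring the updated FileCheckpointStorage logic above.
    with stats_file.open("at") as tf:
        tf.write("".join(f"{s}\n" for s in stats_json))
```

Since only one trainer now calls this, once per pass, there is no longer a window where concurrent appends can interleave within the file.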

torchbiggraph/train.py

Lines changed: 41 additions & 16 deletions
```diff
@@ -26,6 +26,7 @@
 )
 from torchbiggraph.bucket_scheduling import (
     AbstractBucketScheduler,
+    BucketStats,
     DistributedBucketScheduler,
     LockServer,
     SingleMachineBucketScheduler,
@@ -557,6 +558,7 @@ def load_embeddings(
     def swap_partitioned_embeddings(
         old_b: Optional[Bucket],
         new_b: Optional[Bucket],
+        old_stats: Optional[BucketStats],
     ):
         # 0. given the old and new buckets, construct data structures to keep
         # track of old and new embedding (entity, part) tuples
@@ -577,6 +579,8 @@ def swap_partitioned_embeddings(
         # 1. checkpoint embeddings that will not be used in the next pair
         #
         if old_b is not None:  # there are previous embeddings to checkpoint
+            if old_stats is None:
+                raise TypeError("Got old bucket but not its stats")
             logger.info("Writing partitioned embeddings")
             for entity, part in to_checkpoint:
                 side = old_parts[(entity, part)]
@@ -593,7 +597,7 @@ def swap_partitioned_embeddings(
                 del embs
                 del optim_state
 
-            bucket_scheduler.release_bucket(old_b)
+            bucket_scheduler.release_bucket(old_b, old_stats)
 
         # 2. copy old embeddings that will be used in the next pair
         # into a temporary dictionary
@@ -669,19 +673,22 @@ def swap_partitioned_embeddings(
         sync.barrier()
 
         remaining = total_buckets
-        cur_b = None
+        cur_b: Optional[Bucket] = None
+        cur_stats: Optional[BucketStats] = None
         while remaining > 0:
-            old_b = cur_b
+            old_b: Optional[Bucket] = cur_b
+            old_stats: Optional[BucketStats] = cur_stats
             io_time = 0.
             io_bytes = 0
             cur_b, remaining = bucket_scheduler.acquire_bucket()
             logger.info(f"still in queue: {remaining}")
             if cur_b is None:
+                cur_stats = None
                 if old_b is not None:
                     # if you couldn't get a new pair, release the lock
                     # to prevent a deadlock!
                     tic = time.time()
-                    io_bytes += swap_partitioned_embeddings(old_b, None)
+                    io_bytes += swap_partitioned_embeddings(old_b, None, old_stats)
                     io_time += time.time() - tic
                 time.sleep(1)  # don't hammer td
                 continue
@@ -690,7 +697,7 @@ def swap_partitioned_embeddings(
 
             tic = time.time()
 
-            io_bytes += swap_partitioned_embeddings(old_b, cur_b)
+            io_bytes += swap_partitioned_embeddings(old_b, cur_b, old_stats)
 
             current_index = \
                 (iteration_manager.iteration_idx + 1) * total_buckets - remaining
@@ -803,19 +810,18 @@ def swap_partitioned_embeddings(
                 eval_stats_after = Stats.sum(all_eval_stats_after).average()
                 bucket_logger.info(f"Stats after training: {eval_stats_after}")
 
-            # Add train/eval metrics to queue
-            stats_dict = {
-                "index": current_index,
-                "stats": stats.to_dict(),
-            }
-            if eval_stats_before is not None:
-                stats_dict["eval_stats_before"] = eval_stats_before.to_dict()
-            if eval_stats_after is not None:
-                stats_dict["eval_stats_after"] = eval_stats_after.to_dict()
-            checkpoint_manager.append_stats(stats_dict)
             yield current_index, eval_stats_before, stats, eval_stats_after
 
-        swap_partitioned_embeddings(cur_b, None)
+            cur_stats = BucketStats(
+                lhs_partition=cur_b.lhs,
+                rhs_partition=cur_b.rhs,
+                index=current_index,
+                train=stats,
+                eval_before=eval_stats_before,
+                eval_after=eval_stats_after,
+            )
+
+        swap_partitioned_embeddings(cur_b, None, cur_stats)
 
         # Distributed Processing: all machines can leave the barrier now.
         sync.barrier()
@@ -858,6 +864,25 @@ def swap_partitioned_embeddings(
             OptimizerStateDict(trainer.global_optimizer.state_dict()),
         )
 
+        logger.info("Writing the training stats")
+        all_stats_dicts: List[Dict[...]] = []
+        for stats in bucket_scheduler.get_stats_for_pass():
+            stats_dict = {
+                "epoch_idx": epoch_idx,
+                "edge_path_idx": edge_path_idx,
+                "edge_chunk_idx": edge_chunk_idx,
+                "lhs_partition": stats.lhs_partition,
+                "rhs_partition": stats.rhs_partition,
+                "index": stats.index,
+                "stats": stats.train.to_dict(),
+            }
+            if stats.eval_before is not None:
+                stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
+            if stats.eval_after is not None:
+                stats_dict["eval_stats_after"] = stats.eval_after.to_dict()
+            all_stats_dicts.append(stats_dict)
+        checkpoint_manager.append_stats(all_stats_dicts)
+
         logger.info("Writing the checkpoint")
         checkpoint_manager.write_new_version(config)
 
```
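Putting it together: each entry appended at the end of a pass is a flat JSON object keyed exactly as the updated test in test/test_functional.py expects. A sketch of one such entry (all values are made up):

```python
import json

# Illustrative values only; in the real code these come from the iteration
# manager and from the BucketStats collected by the bucket scheduler.
stats_dict = {
    "epoch_idx": 0,
    "edge_path_idx": 0,
    "edge_chunk_idx": 0,
    "lhs_partition": 1,
    "rhs_partition": 2,
    "index": 5,
    "stats": {"count": 1000, "metrics": {"loss": 0.42}},
}
line = json.dumps(stats_dict)
print(json.loads(line)["index"])  # 5
```

The integer keys identify which bucket of which pass the entry belongs to, which is what makes it safe to purge or deduplicate stats after resuming from a checkpoint.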
